Hi NV team! Thanks for sharing this work!
I noticed that the to_device() function in helper.py currently casts every tensor to whatever dtype is given. Token tensors like input_ids and attention_mask get converted even though they typically stay as integers for Hugging Face models, so under mixed precision these fields become floats and break embedding lookups. Could the dtype cast be applied only to tensors that are already floating point, by checking is_floating_point() first? That would preserve critical integer/bool fields while still allowing correct mixed-precision inference for the remaining tensors.
```python
import collections.abc

import torch


def to_device(data, device=None, dtype=None):
    """Recursively move tensors to `device`; cast only floating-point
    tensors to `dtype`, preserving integer/bool fields such as
    input_ids and attention_mask."""
    if isinstance(data, torch.Tensor):
        if dtype is not None and data.is_floating_point():
            return data.to(device=device, dtype=dtype)
        return data.to(device=device)
    elif isinstance(data, collections.abc.Mapping):
        return {k: to_device(v, device=device, dtype=dtype) for k, v in data.items()}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        return [to_device(x, device=device, dtype=dtype) for x in data]
    else:
        return data
```
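For reference, here's a quick self-contained sanity check (CPU-only, with made-up batch values) showing that with the is_floating_point() guard, integer token fields keep their dtype while float tensors pick up the requested one:

```python
import collections.abc

import torch


def to_device(data, device=None, dtype=None):
    # Same guarded-cast logic as proposed above.
    if isinstance(data, torch.Tensor):
        if dtype is not None and data.is_floating_point():
            return data.to(device=device, dtype=dtype)
        return data.to(device=device)
    elif isinstance(data, collections.abc.Mapping):
        return {k: to_device(v, device=device, dtype=dtype) for k, v in data.items()}
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        return [to_device(x, device=device, dtype=dtype) for x in data]
    return data


batch = {
    "input_ids": torch.tensor([[101, 2023, 102]]),  # int64 token ids
    "attention_mask": torch.tensor([[1, 1, 1]]),    # int64 mask
    "pixel_values": torch.randn(1, 3, 4, 4),        # float32 features
}
out = to_device(batch, device="cpu", dtype=torch.float16)
print(out["input_ids"].dtype)     # torch.int64 (preserved)
print(out["pixel_values"].dtype)  # torch.float16 (cast)
```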