
torch.sum promotes integral tensors to int64. #82159

@nikitaved


🐛 Describe the bug

Repro:

In [1]: import torch
   ...: dtypes = [torch.bool, torch.int8, torch.int32, torch.bfloat16, torch.float32, torch.float64]
   ...: for dtype in dtypes:
   ...:     a = torch.tensor([], dtype=dtype)
   ...:     a_sum = a.sum()
   ...:     if a.dtype != a_sum.dtype:
   ...:         print(f"t.dtype != t.sum().dtype, got {a.dtype} != {a_sum.dtype}")
   ...: 
t.dtype != t.sum().dtype, got torch.bool != torch.int64
t.dtype != t.sum().dtype, got torch.int8 != torch.int64
t.dtype != t.sum().dtype, got torch.int32 != torch.int64

This is the cause of #82150.
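For anyone hitting this in the meantime, a minimal sketch of a workaround, assuming the promotion is intended behavior and that passing the `dtype` keyword of `sum` is acceptable in your code. The helper name `sum_keep_dtype` is hypothetical and not part of PyTorch:

```python
import torch


def sum_keep_dtype(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sum `t` and keep the input dtype.

    Passing `dtype=` explicitly overrides the default promotion of
    integral inputs to int64. The result can overflow for narrow
    integer dtypes, since it is stored in the input dtype.
    """
    return t.sum(dtype=t.dtype)


# Compare the default result dtype with the explicitly requested one.
for dtype in [torch.int8, torch.int32, torch.float32]:
    a = torch.tensor([1, 2, 3], dtype=dtype)
    print(f"{dtype}: default -> {a.sum().dtype}, explicit -> {sum_keep_dtype(a).dtype}")
```

This only sidesteps the promotion at the call site; it does not change what `sum()` returns by default.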

Versions

Current master.

cc @svekars @carljparker @nairbv @mruberry @holly1238


Labels: actionable, module: docs, module: reductions, module: type promotion, triaged
