[torchlib] Fix linspace implementation for int64#2693

Merged

justinchuby merged 15 commits into microsoft:main from Aravind-11:fix-linspace-int64 on Jan 8, 2026

Conversation

@Aravind-11
Contributor

Description

Fixes #854 - linspace now correctly handles int64 dtype

Changes

  • Modified aten_linspace to compute in floating-point then cast to target dtype
  • This matches PyTorch's behavior and fixes integer division precision loss
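A minimal sketch of the compute-in-float-then-cast approach (illustrative NumPy, not the actual torchlib code; `linspace_int64` is a hypothetical name):

```python
import numpy as np

def linspace_int64(start, end, steps):
    """Sketch: compute linspace in float64, then truncate to int64.

    Computing directly in integer arithmetic loses the fractional step
    (e.g. (10 - 0) // 4 == 2 exactly, so every step is 2), which is the
    precision loss this PR fixes.
    """
    step = np.float64(end - start) / (steps - 1)
    values = start + step * np.arange(steps, dtype=np.float64)
    return values.astype(np.int64)  # truncates toward zero, like PyTorch

print(linspace_int64(0, 10, 5))  # -> [ 0  2  5  7 10]
```

The float values here are 0, 2.5, 5, 7.5, 10; truncating to int64 gives the `[0, 2, 5, 7, 10]` output described above.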

Testing

Manually verified: linspace(0, 10, 5, dtype=int64) now produces correct output [0, 2, 5, 7, 10]

Questions

Where should I add automated test cases for this fix? Happy to add tests wherever you suggest!


Who can review: @justinchuby

@codecov

codecov bot commented Nov 16, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 70.20%. Comparing base (6e91205) to head (ca6da50).
⚠️ Report is 5 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2693      +/-   ##
==========================================
+ Coverage   70.11%   70.20%   +0.08%     
==========================================
  Files         228      228              
  Lines       27396    27430      +34     
  Branches     2785     2787       +2     
==========================================
+ Hits        19208    19256      +48     
+ Misses       7228     7216      -12     
+ Partials      960      958       -2     


@justinchuby
Collaborator

Thanks. Could you unskip the tests:

@Aravind-11
Contributor Author

> Thanks. Could you unskip the tests:

Thank you for reviewing! Done.

@Aravind-11
Contributor Author

Hi @justinchuby, are there any updates on this that you could let me know about? Thank you!

@github-project-automation github-project-automation bot moved this from Todo to Done in ONNX Script Review Board Dec 9, 2025
@Aravind-11
Contributor Author

Aravind-11 commented Dec 9, 2025

Hi @justinchuby, the CI is failing and I updated the branch. Could you approve the workflows for testing again? Thank you!

@Aravind-11
Contributor Author

Hi @justinchuby , CUDA tests fail because PyTorch itself gives different results on CPU vs CUDA for integer linspace. For example, torch.linspace(4.3, -3, 50, dtype=torch.int64) returns different values at certain indices on CPU vs CUDA.

@justinchuby
Collaborator

Thanks. In CI we only run cpu tests so we should be ok.

@Aravind-11
Contributor Author

> Thanks. In CI we only run cpu tests so we should be ok.

Got it. Thank you. Could you please approve the tests?

@justinchuby justinchuby changed the title Fixes #854 [torchlib] Fix linspace implementation for int64 Dec 11, 2025
@justinchuby justinchuby added the module: torchlib Related to the torch/aten function lib in development label Dec 11, 2025
@justinchuby justinchuby self-assigned this Dec 11, 2025
@Aravind-11
Contributor Author

Thank you for the approval. Let me know if anything else is needed from my side.

@justinchuby
Collaborator

I just realized you are using double precision. Does float32 work? Or is float64 required?

@Aravind-11
Contributor Author

I tested both precisions, and float64 is required for correctness.
With float32, we get precision errors at certain indices:

# Index 21: float32 gives 0.999999761... → truncates to 0
#           float64 gives 1.000000000000... → truncates to 1 
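For concreteness, here is a minimal repro of that index-21 case (my own sketch, assuming `linspace(4.3, -3, 50)` with integer endpoints 4 and -3 and the forward formula `start + step * i`):

```python
import numpy as np

# linspace(4.3, -3, 50): integer start/end are 4 and -3, step = -7/49 = -1/7.
start, end, steps = 4, -3, 50
idx = 21  # the exact value is 4 + 21 * (-1/7) == 1

# Float32: the rounded step accumulates enough error to cross below 1.0.
step32 = np.float32((end - start) / (steps - 1))
val32 = np.float32(start) + step32 * np.float32(idx)
print(val32, int(val32))  # ~0.99999976 -> truncates to 0

# Float64: the same computation lands at (or just above) 1.0.
step64 = np.float64((end - start) / (steps - 1))
val64 = np.float64(start) + step64 * np.float64(idx)
print(val64, int(val64))  # ~1.0 -> truncates to 1
```

Truncation toward zero then turns a sub-ulp float32 error into an off-by-one integer, which is why float64 is required here.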

@Aravind-11
Contributor Author

This is the code I used to test:

import torch
import numpy as np

def test_precision(start, end, steps, dtype_name):
    print(f"\n{'='*60}")
    print(f"linspace({start}, {end}, {steps}) with {dtype_name}")
    print(f"{'='*60}")

    start_int, end_int = int(start), int(end)

    # Float32: forward from start for the first half, backward from end
    # for the second half (mirrors PyTorch's symmetric scheme).
    step_f32 = np.float32((end_int - start_int) / (steps - 1))
    indices_f32 = np.arange(steps, dtype=np.float32)
    forward_f32 = start_int + step_f32 * indices_f32
    backward_f32 = end_int - step_f32 * (steps - 1 - indices_f32)
    result_f32 = np.where(indices_f32 < steps/2, forward_f32, backward_f32).astype(np.int64)

    # Float64: the same computation at double precision.
    step_f64 = np.float64((end_int - start_int) / (steps - 1))
    indices_f64 = np.arange(steps, dtype=np.float64)
    forward_f64 = start_int + step_f64 * indices_f64
    backward_f64 = end_int - step_f64 * (steps - 1 - indices_f64)
    result_f64 = np.where(indices_f64 < steps/2, forward_f64, backward_f64).astype(np.int64)

    # Reference: PyTorch's own integer linspace.
    torch_result = torch.linspace(start, end, steps, dtype=torch.int64).numpy()

    match_f32 = np.array_equal(result_f32, torch_result)
    match_f64 = np.array_equal(result_f64, torch_result)

    print(f"Float32 matches PyTorch: {match_f32}")
    print(f"Float64 matches PyTorch: {match_f64}")

    if not match_f32:
        # Show the first few indices where float32 truncates to the wrong integer.
        diff_indices = np.where(result_f32 != torch_result)[0]
        print(f"\nFloat32 differences at {len(diff_indices)} indices: {diff_indices[:10]}")
        for idx in diff_indices[:3]:
            print(f"  Index {idx}: f32={result_f32[idx]}, f64={result_f64[idx]}, pytorch={torch_result[idx]}")
            print(f"    f32_float={forward_f32[idx] if idx < steps/2 else backward_f32[idx]:.15f}")
            print(f"    f64_float={forward_f64[idx] if idx < steps/2 else backward_f64[idx]:.15f}")

test_precision(4.3, -3, 50, "int64")
test_precision(0, 7, 50, "int64")
test_precision(50, 0, 50, "int64")

@justinchuby
Collaborator

Thanks

@Aravind-11
Contributor Author

Uh oh... looks like we're back to square one 😄. Please let me know what to do, @justinchuby.

@justinchuby
Collaborator

Thanks, I will revert some of my changes

@Aravind-11
Contributor Author

Got it. Thank you.

@Aravind-11
Contributor Author

Hi! Anything blocking the merge? Let me know if I need to change anything. Thank you!

@justinchuby
Collaborator

No, I need to spend some time to fix the commits

@Aravind-11
Contributor Author

> No, I need to spend some time to fix the commits

got it. thank you 😄

@justinchuby
Collaborator

oh I need to fix some typing issues first

@Aravind-11
Contributor Author

> oh I need to fix some typing issues first

Thank you, I'll see if I can debug it after this.

@justinchuby
Collaborator

@Aravind-11 I fixed the type issues. Looks like there are a few int64/int32 tests that are not passing. Could you help with those? Thanks a lot!

@Aravind-11
Contributor Author

> @Aravind-11 I fixed the type issues. Looks like there are a few int64/int32 tests that are not passing. Could you help with those? Thanks a lot!

got it. i'll take a look! :))

@Aravind-11
Contributor Author

@justinchuby please take a look, thank you :)

Ensure double precision is used for computation to match PyTorch's internal precision.
@justinchuby justinchuby enabled auto-merge (squash) January 8, 2026 02:35
@justinchuby justinchuby merged commit 5a338ad into microsoft:main Jan 8, 2026
77 of 80 checks passed

Labels

module: torchlib Related to the torch/aten function lib in development

Development

Successfully merging this pull request may close these issues.

[torchlib] linspace results do not match with PyTorch when dtype is int64

3 participants