[torchlib] Fix linspace implementation for int64 #2693
justinchuby merged 15 commits into microsoft:main
Conversation
Who can review: @justinchuby
Codecov Report: ✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff            @@
##             main    #2693    +/-   ##
========================================
+ Coverage   70.11%   70.20%   +0.08%
========================================
  Files         228      228
  Lines       27396    27430      +34
  Branches     2785     2787       +2
========================================
+ Hits        19208    19256      +48
+ Misses       7228     7216      -12
+ Partials      960      958       -2
```

☔ View full report in Codecov by Sentry.
Thanks. Could you unskip the tests:

Thank you for reviewing! Done.

Hi @justinchuby, are there any updates on this that you could share? Thank you!

Hi @justinchuby, the CI is failing and I updated the branch. Could you approve the workflows for testing again? Thank you!

Hi @justinchuby, the CUDA tests fail because PyTorch itself gives different results on CPU vs CUDA for integer linspace. For example,

Thanks. In CI we only run CPU tests, so we should be ok.

Got it. Thank you. Could you please approve the tests?

Thank you for the approval. Let me know if anything else is needed from my side.

I just realized you are using double precision. Does float32 work? Or is float64 required?
I tested both precisions, and float64 is required for correctness.

```
# Index 21: float32 gives 0.999999761... → truncates to 0
#           float64 gives 1.000000000000... → truncates to 1
```
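For reference, the index-21 discrepancy can be reproduced in isolation with numpy alone. This is a sketch assuming the failing call is `linspace(4.3, -3, 50)` (integer endpoints 4 and -3, so the step is -7/49), which is one of the cases in the test script below:

```python
import numpy as np

# Reconstruct the value at index 21 of linspace(4.3, -3, 50)
# using the integer endpoints 4 and -3 (step = -7/49 = -1/7).
start, end, steps, idx = 4, -3, 50, 21

# float32 arithmetic: the rounded step accumulates enough error
# that the result lands just below the true value of 1
step32 = np.float32((end - start) / (steps - 1))
v32 = np.float32(start) + step32 * np.float32(idx)

# float64 arithmetic: the rounding error stays small enough
# that the product rounds back to exactly -3.0
step64 = np.float64((end - start) / (steps - 1))
v64 = np.float64(start) + step64 * np.float64(idx)

print(v32, int(v32))  # just below 1.0 -> truncates to 0 (wrong)
print(v64, int(v64))  # exactly 1.0    -> truncates to 1 (correct)
```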
This is the code I used to test:

```python
import numpy as np
import torch


def test_precision(start, end, steps, dtype_name):
    print(f"\n{'=' * 60}")
    print(f"linspace({start}, {end}, {steps}) with {dtype_name}")
    print(f"{'=' * 60}")

    start_int, end_int = int(start), int(end)

    # Reconstruct the kernel's arithmetic in float32
    step_f32 = np.float32((end_int - start_int) / (steps - 1))
    indices_f32 = np.arange(steps, dtype=np.float32)
    forward_f32 = start_int + step_f32 * indices_f32
    backward_f32 = end_int - step_f32 * (steps - 1 - indices_f32)
    result_f32 = np.where(indices_f32 < steps / 2, forward_f32, backward_f32).astype(np.int64)

    # Same reconstruction in float64
    step_f64 = np.float64((end_int - start_int) / (steps - 1))
    indices_f64 = np.arange(steps, dtype=np.float64)
    forward_f64 = start_int + step_f64 * indices_f64
    backward_f64 = end_int - step_f64 * (steps - 1 - indices_f64)
    result_f64 = np.where(indices_f64 < steps / 2, forward_f64, backward_f64).astype(np.int64)

    # Reference: PyTorch's own integer linspace
    torch_result = torch.linspace(start, end, steps, dtype=torch.int64).numpy()

    match_f32 = np.array_equal(result_f32, torch_result)
    match_f64 = np.array_equal(result_f64, torch_result)
    print(f"Float32 matches PyTorch: {match_f32}")
    print(f"Float64 matches PyTorch: {match_f64}")

    if not match_f32:
        diff_indices = np.where(result_f32 != torch_result)[0]
        print(f"\nFloat32 differences at {len(diff_indices)} indices: {diff_indices[:10]}")
        for idx in diff_indices[:3]:
            print(f"  Index {idx}: f32={result_f32[idx]}, f64={result_f64[idx]}, pytorch={torch_result[idx]}")
            print(f"    f32_float={(forward_f32[idx] if idx < steps / 2 else backward_f32[idx]):.15f}")
            print(f"    f64_float={(forward_f64[idx] if idx < steps / 2 else backward_f64[idx]):.15f}")


test_precision(4.3, -3, 50, "int64")
test_precision(0, 7, 50, "int64")
test_precision(50, 0, 50, "int64")
```
Thanks

Uh oh... looks like we're back to square one 😄. Please let me know what to do @justinchuby

Thanks, I will revert some of my changes

Got it. Thank you.

Hi! Is anything blocking the merge? Let me know if I need to change anything. Thank you!

No, I need to spend some time to fix the commits

Got it. Thank you 😄

Oh, I need to fix some typing issues first

Thank you, I'll see if I can debug it after this?

@Aravind-11 I fixed the type issues. Looks like there are a few int64/int32 tests that are not passing. Could you help with those? Thanks a lot!

Got it. I'll take a look! :))

@justinchuby please take a look, thank you :)
Ensure double precision is used for computation to match PyTorch's internal precision.

Description

Fixes #854: linspace now correctly handles the int64 dtype.

Changes

- Updated `aten_linspace` to compute in floating point, then cast to the target dtype.

Testing

Manually verified:

- `linspace(0, 10, 5, dtype=int64)` now produces the correct output `[0, 2, 5, 7, 10]`.

Questions

Where should I add automated test cases for this fix? Happy to add tests wherever you suggest!
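The compute-in-float64-then-cast approach described above can be sketched in plain numpy. This is an illustration only, not the actual `aten_linspace` implementation (the operator is built from ONNX ops, not numpy), and the helper name `linspace_int64` is made up for this sketch:

```python
import numpy as np

# Sketch of the fix: compute linspace in float64, then truncate-cast
# to the integer dtype, mirroring PyTorch's internal precision.
def linspace_int64(start: int, end: int, steps: int) -> np.ndarray:
    if steps == 1:
        return np.array([start], dtype=np.int64)
    step = np.float64(end - start) / (steps - 1)
    indices = np.arange(steps, dtype=np.float64)
    # Evaluate from both ends, as PyTorch does, so the endpoints stay exact.
    forward = start + step * indices
    backward = end - step * (steps - 1 - indices)
    return np.where(indices < steps / 2, forward, backward).astype(np.int64)

print(linspace_int64(0, 10, 5))  # [ 0  2  5  7 10]
```

This reproduces the manually verified case from the description: the float64 values `[0, 2.5, 5, 7.5, 10]` truncate to `[0, 2, 5, 7, 10]`.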