torch.linalg.norm and torch.norm give wrong result with torch.complex32 type tensor #132634

@MezoBlast

🐛 Describe the bug

The torch.norm and torch.linalg.norm functions give wrong results for tensors of dtype torch.complex32.
The results are correct for torch.complex64 and torch.complex128.
I have observed the problem for both the 2-norm (p=2) and the 1-norm (p=1).

A direct implementation of the 2-norm, however, gives the right answer even with a torch.complex32 input.
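For reference, the expected values for the example vector used in this report, x = (1, i), can be worked out by hand. This is only a sanity check of the numbers quoted in the code comments, using the standard definitions of the vector 1-norm and 2-norm:

```latex
\[
\lVert x \rVert_2 = \sqrt{|1 + 0i|^2 + |0 + 1i|^2} = \sqrt{1 + 1} = \sqrt{2} \approx 1.4142,
\qquad
\lVert x \rVert_1 = |1 + 0i| + |0 + 1i| = 1 + 1 = 2.
\]
```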

import torch

# example
complex_tensor = [1.0+0.0j, 0.0+1.0j] # 2-norm should approximately be 1.4142, 1-norm should be 2
x = torch.tensor(complex_tensor, dtype=torch.complex32)
y = torch.tensor(complex_tensor, dtype=torch.complex64)

x_norm_0 = torch.norm(x) # result is 1, wrong
x_norm_1 = torch.linalg.norm(x) # result is 1, wrong
x_norm_2 = torch.norm(x, p=1) # result is 1, wrong
x_norm_3 = torch.linalg.norm(x, ord=1) # result is 1, wrong

x_norm_4 = torch.sum(torch.abs(x) ** 2) ** 0.5 # result is 1.4142, correct

y_norm_0 = torch.norm(y) # result is 1.4142, correct
y_norm_1 = torch.linalg.norm(y) # result is 1.4142, correct
y_norm_2 = torch.norm(y, p=1) # result is 2, correct
y_norm_3 = torch.linalg.norm(y, ord=1) # result is 2, correct

y_norm_4 = torch.sum(torch.abs(y) ** 2) ** 0.5 # result is 1.4142, correct
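Since the complex64 results above are correct, one possible workaround until this is fixed is to upcast complex32 tensors before calling the norm. The following is a minimal sketch, not an official fix: `norm_with_upcast` is a hypothetical helper name, and it assumes that converting complex32 to complex64 with `.to()` is supported on the device in use.

```python
import torch


def norm_with_upcast(t: torch.Tensor, ord=None) -> torch.Tensor:
    """Hypothetical helper: upcast complex32 input before taking the norm."""
    if t.dtype == torch.complex32:
        # Assumption: complex64 is handled correctly, as observed in this report.
        t = t.to(torch.complex64)
    return torch.linalg.norm(t, ord=ord)


x = torch.tensor([1.0 + 0.0j, 0.0 + 1.0j], dtype=torch.complex32)
two_norm = norm_with_upcast(x)         # expected to be close to sqrt(2)
one_norm = norm_with_upcast(x, ord=1)  # expected to be close to 2
print(two_norm, one_norm)
```

Note that the result is then a float32 tensor rather than float16, so it may need to be cast back if half precision is required downstream.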

### Versions

Collecting environment information...
PyTorch version: 2.2.1
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.10.9 | packaged by conda-forge | (main, Feb  2 2023, 20:20:04) [GCC 11.3.0] (64-bit runtime)
Python platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 550.54.14
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 48 bits virtual
Byte Order:                         Little Endian
CPU(s):                             32
On-line CPU(s) list:                0-31
Vendor ID:                          GenuineIntel
Model name:                         13th Gen Intel(R) Core(TM) i9-13900K
CPU family:                         6
Model:                              183
Thread(s) per core:                 2
Core(s) per socket:                 24
Socket(s):                          1
Stepping:                           1
CPU max MHz:                        5800.0000
CPU min MHz:                        800.0000
BogoMIPS:                           5990.40
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi vnmi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                     VT-x
L1d cache:                          896 KiB (24 instances)
L1i cache:                          1.3 MiB (24 instances)
L2 cache:                           32 MiB (12 instances)
L3 cache:                           36 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-31
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI BHI_DIS_S
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] optree==0.10.0
[pip3] pytorch-msssim==0.2.1
[pip3] torch==2.2.1
[pip3] torchaudio==2.2.1
[pip3] torchelastic==0.2.2
[pip3] torchvision==0.17.1
[pip3] triton==2.2.0
[conda] blas                      1.0                         mkl  
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0           py310h5eee18b_1  
[conda] mkl_fft                   1.3.8           py310h5eee18b_0  
[conda] mkl_random                1.2.4           py310hdb19cb5_0  
[conda] numpy                     1.26.3          py310h5f9d8c6_0  
[conda] numpy-base                1.26.3          py310hb5e798b_0  
[conda] optree                    0.10.0                   pypi_0    pypi
[conda] pytorch                   2.2.1           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-msssim            0.2.1                    pypi_0    pypi
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] torchaudio                2.2.1               py310_cu121    pytorch
[conda] torchelastic              0.2.2                    pypi_0    pypi
[conda] torchtriton               2.2.0                     py310    pytorch
[conda] torchvision               0.17.1              py310_cu121    pytorch


cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @anjali411 @dylanbespalko @mruberry @Lezcano @nikitaved @amjames @jianyuh @pearu @walterddr @xwang233

Metadata

Labels

high priority

module: complex (Related to complex number support in PyTorch)

module: correctness (silent) (issue that returns an incorrect result silently)

module: half (Related to float16 half-precision floats)

module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)

triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)

Projects

Status

Done
