UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_reduce.two. #134797

@aboubezari

🐛 Describe the bug

Note: the environment reported below is on 2.1.2, but I have also tried 2.3.0 and the issue is still there.

UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::scatter_reduce.two.
Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at ../aten/src/ATen/functorch/BatchedFallback.cpp:81.)
  scattered_1d = torch.scatter_reduce(input=tensor_1d,

You can reproduce it with any function that calls torch.scatter_reduce under torch.vmap:

import torch

src = torch.tensor([1., 2., 3., 4., 5., 6.])
index = torch.tensor([1, 1, 0, 1, 2, 1])
input = torch.tensor([1., 2., 3., 4.])

# Simulate a batch dimension of 1
src = src.unsqueeze(0)
index = index.unsqueeze(0)
input = input.unsqueeze(0)

def _fn(inputs):
    _src, _index, _input = inputs
    return torch.scatter_reduce(_input, 0, _index, _src, reduce="sum")

result = torch.vmap(_fn)((src, index, input))
print(result)
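
A possible workaround, which is not part of the original report, is sketched below. It assumes the batch dimension is explicit and leading, as in the reproduction above. In that case the same values can likely be computed without vmap by scattering along dim=1, so that each batch row is reduced independently. The name `base` is a hypothetical rename of `input`, chosen to avoid shadowing the builtin, and the per-sample loop is there only as a reference to check against.

```python
import torch

# Same data as the reproduction, with the batch dimension already in place.
src = torch.tensor([[1., 2., 3., 4., 5., 6.]])   # shape (1, 6)
index = torch.tensor([[1, 1, 0, 1, 2, 1]])       # shape (1, 6)
base = torch.tensor([[1., 2., 3., 4.]])          # shape (1, 4), hypothetical name

# Scatter along dim=1: row i of `src` is reduced into row i of `base`,
# which is what vmap over dim 0 of a dim=0 scatter would compute.
batched = torch.scatter_reduce(base, 1, index, src, reduce="sum")

# Per-sample reference: apply the unbatched op to each row and stack.
looped = torch.stack([
    torch.scatter_reduce(b, 0, i, s, reduce="sum")
    for b, i, s in zip(base, index, src)
])

print(batched)
print(torch.equal(batched, looped))
```

This only sidesteps the warning when the function under vmap is simple enough to be rewritten with an explicit batch dimension. It does not address the missing batching rule itself.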

Versions

PyTorch version: 2.1.2
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.27.7
Libc version: glibc-2.31

Python version: 3.11.4 (main, Jul 5 2023, 13:45:01) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.8.0-59-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA RTX A4000
Nvidia driver version: 535.183.01
cuDNN version: Probably one of the following:
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-12.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 20
On-line CPU(s) list: 0-19
Thread(s) per core: 2
Core(s) per socket: 10
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 79
Model name: Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz
Stepping: 1
CPU MHz: 2145.402
CPU max MHz: 3100.0000
CPU min MHz: 1200.0000
BogoMIPS: 4399.76
Virtualization: VT-x
L1d cache: 320 KiB
L1i cache: 320 KiB
L2 cache: 2.5 MiB
L3 cache: 25 MiB
NUMA node0 CPU(s): 0-19
Vulnerability Itlb multihit: KVM: Mitigation: VMX disabled
Vulnerability L1tf: Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Full generic retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Mitigation; Clear CPU buffers; SMT vulnerable
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single pti intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm ida arat pln pts md_clear flush_l1d

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] onnx==1.15.0
[pip3] onnx-graphsurgeon==0.3.27
[pip3] onnxruntime==1.17.0
[pip3] optree==0.11.0
[pip3] torch==2.1.2
[pip3] torchaudio==2.1.2
[pip3] torchvision==0.16.2
[pip3] triton==2.1.0
[conda] blas 1.0 mkl
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libjpeg-turbo 2.0.0 h9bf148f_0 pytorch
[conda] mkl 2023.1.0 h213fc3f_46344
[conda] mkl-service 2.4.0 py311h5eee18b_1
[conda] mkl_fft 1.3.8 py311h5eee18b_0
[conda] mkl_random 1.2.4 py311hdb19cb5_0
[conda] numpy 1.26.3 py311h08b1b3b_0
[conda] numpy-base 1.26.3 py311hf175353_0
[conda] optree 0.11.0 pypi_0 pypi
[conda] pytorch 2.1.2 py3.11_cuda11.8_cudnn8.7.0_0 pytorch
[conda] pytorch-cuda 11.8 h7e8668a_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchaudio 2.1.2 py311_cu118 pytorch
[conda] torchtriton 2.1.0 py311 pytorch
[conda] torchvision 0.16.2 py311_cu118 pytorch

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @Chillee @samdow @kshitij12345

Labels

high priority
module: functorch (Pertaining to torch.func or pytorch/functorch)
module: vmap
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)