Closed
Labels
good first issue
module: correctness (silent) (issue that returns an incorrect result silently)
module: error checking (bugs related to incorrect/lacking error checking)
triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Description
🐛 Describe the bug
masked_scatter fails during the backward gradient computation with RuntimeError: masked_select: expected BoolTensor or ByteTensor for mask. However, the forward computation raises no such error for the same non-boolean mask, even when the tensors require grad. I am not sure which phase is buggy, but the two should at least be consistent.
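A possible workaround, sketched under assumptions rather than offered as a confirmed fix: the error message suggests the backward pass goes through masked_select, which requires a boolean mask. Creating the mask with dtype=torch.bool and differentiating only with respect to the floating-point arguments should avoid the mismatch. The name `source` and the choice to capture `mask` in a closure are illustrative, and the non-vectorized jacobian path is used here only to keep the sketch simple.

```python
import torch
from torch.autograd.functional import jacobian

input_tensor = torch.ones(1, 3)
# Assumption: a boolean mask is what both forward and backward expect.
mask = torch.ones(1, 3, dtype=torch.bool)
source = torch.ones(3, 4)


def func(input_tensor, source):
    # The mask is captured from the enclosing scope, so it is never
    # treated as a differentiable input.
    return torch.masked_scatter(input_tensor, mask, source)


# Default (non-vectorized) reverse-mode Jacobian w.r.t. both float inputs.
jac_input, jac_source = jacobian(func, (input_tensor, source))

print(jac_input.shape)
print(jac_source.shape)
```

Since every mask entry is True in this sketch, the output depends only on the first three elements of `source`, so the Jacobian with respect to `input_tensor` is all zeros.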
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev, jacfwd
torch.manual_seed(420)
input_tensor = torch.ones(1, 3)
mask = torch.ones(1, 3)
tensor = torch.ones(3, 4)
def func(input_tensor, mask, tensor):
    output_tensor = torch.masked_scatter(input_tensor, mask, tensor)
    return output_tensor
func(input_tensor, mask, tensor)
# succeed
func(input_tensor.clone().requires_grad_(), mask.clone().requires_grad_(), tensor.clone().requires_grad_())
# succeed
jacobian(func, (input_tensor, mask, tensor), vectorize=True, strategy="reverse-mode")
# RuntimeError: masked_select: expected BoolTensor or ByteTensor for mask
jacrev(func)(input_tensor, mask, tensor)
# RuntimeError: masked_select: expected BoolTensor or ByteTensor for mask
Versions
PyTorch version: 2.0.0.dev20230105
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.3.0-1ubuntu1~22.04) 11.3.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.9.15 (main, Nov 24 2022, 14:31:59) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.7.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
Nvidia driver version: 515.86.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.4.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.4.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] torch==2.0.0.dev20230105
[pip3] torchaudio==2.0.0.dev20230105
[pip3] torchvision==0.15.0.dev20230105
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] numpy 1.23.5 py39h14f4228_0
[conda] numpy-base 1.23.5 py39h31eccc5_0
[conda] pytorch 2.0.0.dev20230105 py3.9_cuda11.7_cudnn8.5.0_0 pytorch-nightly
[conda] pytorch-cuda 11.7 h67b0de4_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.0.0.dev20230105 py39_cu117 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py39 pytorch-nightly
[conda] torchvision 0.15.0.dev20230105 py39_cu117 pytorch-nightly