🐛 Bug
When using torch.nn.functional.nll_loss with 16-bit (float16) CUDA tensors, the default reduction='mean' produces NaN. Performing the reduction manually (reduction='none' followed by .mean()) gives an accurate result at an arguably acceptable latency cost. This behavior is not mentioned in the nll_loss documentation, so it came as a surprise to me.
I haven't been able to step through the C++ implementation to tell whether this is actually a bug or just an undocumented limitation of the algorithm used to compute this reduction with half-precision tensors on GPUs.
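My guess (unverified, since I couldn't step through the kernel) is a float16 overflow in the reduction: with N*H*W = 64*350*350 ≈ 7.84M per-element losses of magnitude ~3, the running sum (~2.35e7) and the total weight (~7.84e6) both exceed the float16 maximum of 65504, and inf/inf yields NaN. A minimal sketch of that failure mode, using the same shapes and magnitudes as the repro below (this is my speculation about the cause, not the actual kernel):

import torch

# ~7.84M per-element losses of magnitude ~3.0, as in the repro below
n = 64 * 350 * 350
x = torch.full((n,), 3.0, dtype=torch.float16, device='cuda:0')

print(x.sum())   # inf: the true sum (~2.35e7) exceeds the float16 max (65504)
print(x.mean())  # tensor(3., ...): mean() divides before casting back to float16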
To Reproduce
import torch
import torch.nn.functional as F
from timeit import default_timer as timer


def _compare_cuda(log_probs: torch.Tensor, targets: torch.Tensor):
    # Compute and time manual reduction
    start_manual = torch.cuda.Event(enable_timing=True)
    end_manual = torch.cuda.Event(enable_timing=True)
    start_manual.record()
    loss_manual = F.nll_loss(
        input=log_probs,
        target=targets,
        reduction='none'
    )
    loss_manual = loss_manual.mean()
    end_manual.record()

    # Compute and time auto reduction
    start_auto = torch.cuda.Event(enable_timing=True)
    end_auto = torch.cuda.Event(enable_timing=True)
    start_auto.record()
    loss_auto = F.nll_loss(
        input=log_probs,
        target=targets,
        reduction='mean'
    )
    end_auto.record()

    # Calculate times
    torch.cuda.synchronize()
    time_manual_ms = start_manual.elapsed_time(end_manual)
    time_auto_ms = start_auto.elapsed_time(end_auto)
    return loss_manual, time_manual_ms, loss_auto, time_auto_ms


def _compare_cpu(log_probs: torch.Tensor, targets: torch.Tensor):
    # Compute and time manual reduction
    start_manual = timer()
    loss_manual = F.nll_loss(
        input=log_probs,
        target=targets,
        reduction='none'
    )
    loss_manual = loss_manual.mean()
    time_manual_ms = (timer() - start_manual) * 1000

    # Compute and time auto reduction
    start_auto = timer()
    loss_auto = F.nll_loss(
        input=log_probs,
        target=targets,
        reduction='mean'
    )
    time_auto_ms = (timer() - start_auto) * 1000
    return loss_manual, time_manual_ms, loss_auto, time_auto_ms


def compare_nll_loss(device: str, logits_type: torch.dtype) -> None:
    torch.manual_seed(2021)
    # Reproduced example from DeepLabV3-ResNet50, on PascalVOC 2012
    N, C, H, W = 64, 21, 350, 350
    logits_mean = 0.0066
    logits_std = 0.0487
    # Create targets with shape (N, H, W) and values between [0, C-1] for C classes
    targets = torch.randint(high=C, size=(N, H, W), device=device, dtype=torch.int64)
    # Create logits with shape (N, C, H, W) for C classes
    logits = torch.normal(mean=logits_mean, std=logits_std, size=(N, C, H, W))
    logits = logits.type(logits_type).to(device)
    log_probs = F.log_softmax(logits, dim=1)

    if 'cuda' in device:
        loss_manual, time_manual_ms, loss_auto, time_auto_ms = \
            _compare_cuda(log_probs=log_probs, targets=targets)
    else:
        loss_manual, time_manual_ms, loss_auto, time_auto_ms = \
            _compare_cpu(log_probs=log_probs, targets=targets)

    print('-' * 60)
    print(f'Comparison with device={device}, dtype={logits_type}')
    print('-' * 60)
    print(f'Manual reduction ({time_manual_ms:.3f}ms): {loss_manual}')
    print(f'Auto reduction ({time_auto_ms:.3f}ms): {loss_auto}')
    print('-' * 60)


if __name__ == '__main__':
    # CPU tests (16-bit not supported)
    compare_nll_loss(device='cpu', logits_type=torch.float64)
    compare_nll_loss(device='cpu', logits_type=torch.float32)
    # GPU tests
    compare_nll_loss(device='cuda:0', logits_type=torch.float64)
    compare_nll_loss(device='cuda:0', logits_type=torch.float32)
    compare_nll_loss(device='cuda:0', logits_type=torch.float16)

When run on a single GV100 with 16 CPU cores, this gives the following output:
------------------------------------------------------------
Comparison with device=cpu, dtype=torch.float64
------------------------------------------------------------
Manual reduction (36.306ms): 3.045677198428085
Auto reduction (172.581ms): 3.0456771984284803
------------------------------------------------------------
------------------------------------------------------------
Comparison with device=cpu, dtype=torch.float32
------------------------------------------------------------
Manual reduction (45.777ms): 3.0456771850585938
Auto reduction (78.699ms): 3.196272850036621
------------------------------------------------------------
------------------------------------------------------------
Comparison with device=cuda:0, dtype=torch.float64
------------------------------------------------------------
Manual reduction (3.577ms): 3.0456389316191617
Auto reduction (0.844ms): 3.045638931619163
------------------------------------------------------------
------------------------------------------------------------
Comparison with device=cuda:0, dtype=torch.float32
------------------------------------------------------------
Manual reduction (3.353ms): 3.0456390380859375
Auto reduction (0.725ms): 3.0456392765045166
------------------------------------------------------------
------------------------------------------------------------
Comparison with device=cuda:0, dtype=torch.float16
------------------------------------------------------------
Manual reduction (3.197ms): 3.044921875
Auto reduction (0.578ms): nan
------------------------------------------------------------

Expected behavior
I did not expect to get NaN in the 16-bit CUDA case when the manual reduction gives a reasonably accurate result. The other outputs are as expected. The automatic reduction is clearly faster, but the loss of precision makes it unusable in the 16-bit scenario.
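For anyone hitting the same issue, two workarounds that keep the result finite (sketches based on the overflow hypothesis above, not an official fix):

import torch
import torch.nn.functional as F

log_probs = F.log_softmax(
    torch.randn(8, 21, 4, 4, device='cuda:0', dtype=torch.float16), dim=1)
targets = torch.randint(high=21, size=(8, 4, 4), device='cuda:0')

# Workaround A: reduce manually; tensor.mean() appears to accumulate
# at higher precision before casting back to float16
loss_a = F.nll_loss(log_probs, targets, reduction='none').mean()

# Workaround B: upcast to float32 before the loss, cast back if needed
loss_b = F.nll_loss(log_probs.float(), targets, reduction='mean').half()

I believe training under torch.cuda.amp.autocast would also sidestep this, since nll_loss is on autocast's float32 list, but the issue remains for code that feeds float16 tensors to nll_loss directly.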
Environment
Collecting environment information...
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31
Python version: 3.9.6 | packaged by conda-forge | (default, Jul 6 2021, 08:53:59) [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-74-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 11.3.109
GPU models and configuration: GPU 0: NVIDIA Quadro GV100
Nvidia driver version: 465.19.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_adv_train.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8
/usr/local/cuda-11.3/targets/x86_64-linux/lib/libcudnn_ops_train.so.8
HIP runtime version: N/A
MIOpen runtime version: N/A
Versions of relevant libraries:
[pip3] numpy==1.21.0
[pip3] pytorch-lightning==1.3.8
[pip3] torch==1.8.1
[pip3] torchmetrics==0.4.1
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.9.1
[conda] blas 2.109 mkl conda-forge
[conda] blas-devel 3.9.0 9_mkl conda-forge
[conda] cudatoolkit 11.1.1 h6406543_8 conda-forge
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] libblas 3.9.0 9_mkl conda-forge
[conda] libcblas 3.9.0 9_mkl conda-forge
[conda] liblapack 3.9.0 9_mkl conda-forge
[conda] liblapacke 3.9.0 9_mkl conda-forge
[conda] mkl 2021.3.0 h726a3e6_557 conda-forge
[conda] mkl-devel 2021.3.0 ha770c72_558 conda-forge
[conda] mkl-include 2021.3.0 h726a3e6_557 conda-forge
[conda] numpy 1.21.0 py39hdbf815f_0 conda-forge
[conda] pytorch 1.8.1 py3.9_cuda11.1_cudnn8.0.5_0 pytorch
[conda] pytorch-lightning 1.3.8 pypi_0 pypi
[conda] torchmetrics 0.4.1 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.9.1 py39_cu111 pytorch