Commit 9f8c51c

Add missing boundary checks
This fixes OOB memory access for the following code:

```
import torch
qk = torch.randn((1024,587), dtype=torch.float64, device='cuda')
smqk = torch.softmax(qk, dim=-1)
```
1 parent 99c8d5a commit 9f8c51c

1 file changed: +2 -2 lines changed

aten/src/ATen/native/cuda/SoftMax.cu

Lines changed: 2 additions & 2 deletions
@@ -465,7 +465,7 @@ ilpReduce(index_t shift,
   if(shift > 0){
     data -= shift;
     size += shift;
-    if(threadIdx.x >= shift){
+    if (offset >= shift && offset < size) {
       threadVal = r(threadVal, data[offset]);
     }
     size -= blockDim.x;
@@ -515,7 +515,7 @@ WriteFpropResultsVectorized(
     output -= shift;
     size += shift;
 
-    if (threadIdx.x >= shift) {
+    if (offset >= shift && offset < size) {
       output[offset] = epilogue(input[offset]);
     }
     size -= blockDim.x;
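
The effect of the new guard can be illustrated with a small host-side simulation. This is only a sketch, not the kernel itself: it assumes that `offset` equals `threadIdx.x` in this head step (the diff does not show where `offset` is defined), and the function name `head_step_offsets` and the numbers used are hypothetical. It lists which offsets one block would touch under the old and new conditions.

```python
def head_step_offsets(block_dim, shift, size, use_fix):
    """Simulate the unaligned head step for a single block.

    `size` is the row length before the adjustment. As in the diff, the
    pointer is moved back by `shift` and `size` grows by `shift`, so the
    valid offsets are shift <= offset < size (after the adjustment).
    """
    size += shift  # mirrors `size += shift` in the kernel
    accessed = []
    for tid in range(block_dim):
        offset = tid  # assumption: offset == threadIdx.x in this step
        if use_fix:
            allowed = shift <= offset < size  # new guard
        else:
            allowed = tid >= shift  # old guard: no upper bound
        if allowed:
            accessed.append(offset)
    out_of_bounds = [o for o in accessed if o >= size]
    return accessed, out_of_bounds


# Hypothetical case: a block with more threads than elements in the row.
old_accessed, old_oob = head_step_offsets(block_dim=8, shift=1, size=4, use_fix=False)
new_accessed, new_oob = head_step_offsets(block_dim=8, shift=1, size=4, use_fix=True)
print("old guard out-of-bounds offsets:", old_oob)
print("new guard out-of-bounds offsets:", new_oob)
```

In this hypothetical case the old condition lets threads past the end of the row read or write memory, while the new condition limits access to offsets in `[shift, size)`.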
