Skip to content

Commit 015bfc9

Browse files
Update
[ghstack-poisoned]
1 parent 8d469be commit 015bfc9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

aten/src/ATen/native/mps/operations/ScanKernel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ static void scan_with_indices_mps_impl(const Tensor& self,
236236
constexpr int bn = 32;
237237
size_t stride_blocks = (stride + bn - 1) / bn;
238238

239-
mtl_setArgs<3>(computeEncoder, axis_siz, stride, stride_blocks);
239+
mtl_setArgs<3>(computeEncoder, axis_size, stride, stride_blocks);
240240

241241
int n_reads = (input_tensor.element_size() <= 4) ? 4 : 2;
242242
int n_simdgroups = bn / n_reads;

0 commit comments

Comments
 (0)