Change reorder_dimensions behavior to favor output writing sequence #28615
reorder_dimensions() currently iterates over all operands when determining the dimension order in the TensorIterator. It moves a dimension to the front if any operand has another dimension whose stride is larger than this dimension's stride.
reorder_dimensions() does respect the case where a stride is zero, but I did not see a reason why it needs to keep probing every operand in the regular case.
This PR changes the behavior slightly.
Since operands is ordered with the output tensors first, followed by the input tensors, I would favor making the writes to the outputs as sequential as possible. This can make copies between tensors with different memory formats faster.
Please correct me if this change is wrong, thanks.
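To illustrate why the output's layout matters here, this sketch compares the strides that the TensorIterator sees for the two memory formats used in the benchmark below. In the contiguous (NCHW) tensor the dimensions are already in decreasing-stride order, while in the channels_last (NHWC) tensor the channel dimension has stride 1, so ordering by the output's strides makes the writes sequential:

```python
import torch

# Same shape as the benchmark below: N=64, C=2048, H=7, W=7.
x = torch.randn(64, 2048, 7, 7)                      # NCHW, contiguous
y = x.contiguous(memory_format=torch.channels_last)  # NHWC physical layout

print(x.stride())  # (100352, 49, 7, 1)     — strides strictly decreasing
print(y.stride())  # (100352, 1, 14336, 2048) — channel dim has stride 1
```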
Fixes #26812
Benchmark on CPU
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.contiguous_format)
%timeit x.contiguous(memory_format=torch.channels_last)
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.channels_last)
%timeit x.contiguous(memory_format=torch.contiguous_format)
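For anyone reproducing this outside IPython (where the %timeit magic is unavailable), a plain-timeit version of the CPU benchmark could look like the following; `bench` is a hypothetical helper name, not part of this PR:

```python
import timeit
import torch

def bench(src_fmt, dst_fmt, number=10):
    """Time one src_fmt -> dst_fmt .contiguous() conversion, averaged."""
    x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=src_fmt)
    total = timeit.timeit(lambda: x.contiguous(memory_format=dst_fmt),
                          number=number)
    return total / number  # seconds per conversion

print(bench(torch.contiguous_format, torch.channels_last))
print(bench(torch.channels_last, torch.contiguous_format))
```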
BEFORE:
20.7 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
12.5 ms ± 49.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
AFTER:
9.26 ms ± 454 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 53.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Benchmark on GPU
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.contiguous_format).cuda()
%timeit x.contiguous(memory_format=torch.channels_last); torch.cuda.synchronize()
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.channels_last).cuda()
%timeit x.contiguous(memory_format=torch.contiguous_format); torch.cuda.synchronize()
BEFORE:
622 µs ± 268 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5.2 µs ± 77.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
AFTER:
379 µs ± 316 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5.25 µs ± 76.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)