
Conversation

@glaringlee (Contributor) commented Oct 24, 2019

reorder_dimensions() currently iterates over all the operands when determining the dimension order in the TensorIterator. It tries to move a dimension toward the front if any operand has a larger stride for that dimension than for the one it is compared against.

reorder_dimensions() does respect the case where a stride is zero, but I do not see a reason why it needs to keep probing every operand in the regular case.

This changes the behavior a little. Since the operands are ordered with the output tensors first, followed by the input tensors, I would favor making the writes to the outputs as sequential as possible. This can make copies between tensors with different memory formats faster.
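
As a rough illustration of the write-locality idea (a Python sketch only, not the actual ATen C++ implementation; the helper name order_dims_by_output_stride is made up for this example), ordering the dimensions by the output tensor's strides makes the innermost loop walk the output contiguously:

# Illustrative sketch only; the real reorder_dimensions() also handles
# zero strides from broadcasting, which this ignores.
import torch

def order_dims_by_output_stride(output):
    # Put the dimension with the largest output stride first and the smallest
    # (ideally 1) last, so the innermost loop writes the output contiguously.
    return sorted(range(output.dim()), key=output.stride, reverse=True)

# A channels_last (N, C, H, W) output of shape (64, 2048, 7, 7) has strides
# (100352, 1, 14336, 2048), so the order becomes [0, 2, 3, 1] and the
# innermost loop runs along the contiguous channel dimension.
out = torch.empty(64, 2048, 7, 7).contiguous(memory_format=torch.channels_last)
print(order_dims_by_output_stride(out))  # [0, 2, 3, 1]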

Please correct me if this change is wrong, thanks.

Fix #26812

Benchmark on CPU
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.contiguous_format)
%timeit x.contiguous(memory_format=torch.channels_last)
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.channels_last)
%timeit x.contiguous(memory_format=torch.contiguous_format)
BEFORE:
20.7 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
12.5 ms ± 49.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
AFTER:
9.26 ms ± 454 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
12.6 ms ± 53.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Benchmark on GPU
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.contiguous_format).cuda()
%timeit x.contiguous(memory_format=torch.channels_last); torch.cuda.synchronize()
x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.channels_last).cuda()
%timeit x.contiguous(memory_format=torch.contiguous_format); torch.cuda.synchronize()
BEFORE:
622 µs ± 268 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5.2 µs ± 77.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
AFTER:
379 µs ± 316 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
5.25 µs ± 76.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
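
For reference, here is a minimal standalone version of the benchmarks above that uses only the standard-library timeit module, so it runs outside IPython; the shape and memory formats match the snippets above, and the repeat count is arbitrary.

import timeit
import torch

x = torch.randn(64, 2048, 7, 7).contiguous(memory_format=torch.contiguous_format)
n = 100
t = timeit.timeit(lambda: x.contiguous(memory_format=torch.channels_last), number=n)
print(f"CPU contiguous -> channels_last: {t / n * 1e3:.2f} ms per loop")

if torch.cuda.is_available():
    xc = x.cuda()

    def to_channels_last_cuda():
        xc.contiguous(memory_format=torch.channels_last)
        torch.cuda.synchronize()

    to_channels_last_cuda()  # warm-up so CUDA initialization is not timed
    t = timeit.timeit(to_channels_last_cuda, number=n)
    print(f"GPU contiguous -> channels_last: {t / n * 1e6:.1f} us per loop")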

@facebook-github-bot (Contributor) left a comment


@glaringlee has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin (Contributor) commented

CC @ngimel

@VitalyFedyunin VitalyFedyunin self-requested a review October 24, 2019 20:40
@facebook-github-bot (Contributor) commented

@glaringlee merged this pull request in 7ed9a3e.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 28, 2019
…(#28615)

Summary:
reorder_dimensions() currently iterates over all the operands when determining the dimension order in the TensorIterator. It tries to move a dimension toward the front if any operand has a larger stride for that dimension than for the one it is compared against.

reorder_dimensions() does respect the case where a stride is zero, but I do not see a reason why it needs to keep probing every operand in the regular case.

This changes the behavior a little. Since the operands are ordered with the output tensors first, followed by the input tensors, I would favor making the writes to the outputs as sequential as possible. This can make copies between tensors with different memory formats faster.

Please correct me if this change is wrong, thanks.
Pull Request resolved: pytorch/pytorch#28615

Reviewed By: VitalyFedyunin

Differential Revision: D18122474

Pulled By: glaringlee

fbshipit-source-id: f36467489fe6c6514b14ce9dcc439628d5d5ad0e


Development

Successfully merging this pull request may close these issues.

TensorIterator traverse order and write locality

4 participants