Vectorized CPU code implementing left shift operator. #88607
alexsamardzic wants to merge 1 commit into master
Conversation
Thanks @alexsamardzic for the PR. Although I'm very interested in this being merged since we need it for …
Force-pushed from 6b36441 to 4f38c14
jgong5 left a comment:
LGTM. I'm requesting @sanchitintel to look into it too.
lezcano left a comment:
Great PR @alexsamardzic!
It generally looks good, but I'll let the Intel folks assess whether there are any operations that can be golfed for efficiency.
```cpp
[](scalar_t a, scalar_t b) -> scalar_t {
  return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
},
[](Vectorized<scalar_t> a, Vectorized<scalar_t> b) {
```
In these lambdas, we pass the Vectorized<scalar_t> by value, but in all the operations in vec256_int and vec512_int the arguments are passed by const-ref. Which one is better? A priori, I'd say that pass-by-value should be the better one, but I don't know the ABI well enough to be 100% sure on this.
Tagging a few folks that may know this: @ezyang @peterbell10 @ysiraichi
If it's inlined, and it should be, it shouldn't matter. If you really care, look at the generated assembly.
That's a fair point. See this godbolt: https://godbolt.org/z/sEbnnE97x. The assembly seems strictly better when passing by value than when passing by const&. I agree that, in the case of one-liners and short lambdas, it doesn't matter, but there are many methods in the Vectorized template that are fairly non-trivial and may or may not be inlined. I reckon we should change them all to take their args by value.
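To make the trade-off concrete, here is a minimal standalone sketch (hypothetical function names, not code from the PR) reproducing what the godbolt link shows; on a SysV x86-64 target with AVX enabled, a `__m256i` passed by value arrives directly in a `ymm` register, while a `const&` arrives as a pointer and forces loads in a non-inlined callee:

```cpp
#include <immintrin.h>

// By value: operands already sit in ymm registers on entry.
__attribute__((noinline))
__m256i shl_by_value(__m256i a, __m256i b) {
    return _mm256_sllv_epi32(a, b);
}

// By const reference: operands arrive as pointers, so the callee
// must load them from memory before shifting.
__attribute__((noinline))
__m256i shl_by_ref(const __m256i& a, const __m256i& b) {
    return _mm256_sllv_epi32(a, b);
}
```

Compile with `-O2 -mavx2` to compare the generated assembly; with inlining the difference disappears, which is the point made above.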
@pmeier asked about vectorization for the uint8 datatype. As mentioned above, this PR was motivated by his benchmark showing that left shift is slower than multiplication for any datatype, so here is some explanation.

On master, in BinaryOpsKernel.cpp, left shift is implemented through cpu_kernel() while multiplication is implemented through cpu_kernel_vec(). In both cases the operation is basically performed in a loop, i.e. without any explicit vectorization, but the multiplication code goes through aten/src/ATen/vec, and that code is compiled with vectorization/AVX flags on, so it gets some benefit. For example, for the uint8 datatype, on qgpu3 and for a 3x256x256 image, the benchmark reports 120us for multiplication and 150us for left shift.

With this PR, the left shift implementation in BinaryOpsKernel.cpp is changed to use cpu_kernel_vec(), so the left shift code now goes through aten/src/ATen/vec too, just like multiplication. Thus for the uint8 datatype, again on qgpu3 and for a 3x256x256 image, the benchmark reports 120us for both multiplication and left shift. For other integer datatypes, however, the left shift is now explicitly vectorized through AVX intrinsics, so it's considerably faster; for int8, for example, the benchmark reports ~40us for the left shift. So the bottom line is that specializations for the uint8_t datatype remain to be created, in both vec256_int.h and vec512_int.h, for all operations that are explicitly vectorized there for other integer datatypes.
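For concreteness, here is a hedged sketch of the change described above (assembled from the diff hunk quoted earlier, not a verbatim copy of the PR; header paths and the wrapper function name are approximate):

```cpp
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <ATen/cpu/vec/vec.h>

namespace at { namespace native {

// Sketch only: cpu_kernel_vec pairs a scalar lambda (used for tail
// elements and as a fallback) with a Vectorized<scalar_t> lambda that
// gets compiled in the AVX-flagged aten/src/ATen/vec translation units.
void lshift_kernel_sketch(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t {
          return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
        },
        [](vec::Vectorized<scalar_t> a, vec::Vectorized<scalar_t> b) {
          return a << b;
        });
  });
}

}} // namespace at::native
```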
```cpp
// Convert 16-bit operands from lane #0 to 32-bit values, and
// perform vectorized shift. Make sure that the upper 16 bits of the
// 32-bit results are all 0.
__m128i a_lo_16 = _mm256_extracti128_si256(a, 0);
__m128i b_lo_16 = _mm256_extracti128_si256(b, 0);
__m256i a_lo_32 = _mm256_cvtepi16_epi32(a_lo_16);
__m256i b_lo_32 = _mm256_cvtepi16_epi32(b_lo_16);
__m256i c_lo_32 = _mm256_and_si256(_mm256_sllv_epi32(a_lo_32, b_lo_32), mask);
```
This operation is probably memory bound anyway, but I don't think you need to cross AVX lanes here. If you shuffle neighboring pairs of int16 from `a b c d ...` into two vectors `0 a 0 c ...` and `0 b 0 d ...`, then shifting should give the same answer as if you had used cvtepi16_epi32, but without needing to cross AVX lanes.
Thanks, I've made the change, and it's indeed faster when crossing lanes is avoided. I will change the other two functions too, and update the PR then.
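A hedged sketch of one way to realize the suggestion above (not the PR's exact code): widen the int16 elements into 32-bit slots using in-lane unpacks against zero instead of the lane-crossing `_mm256_cvtepi16_epi32`, shift per slot, then repack in-lane:

```cpp
#include <immintrin.h>

__m256i lshift_epi16_lane_local(__m256i a, __m256i b) {
    const __m256i zero = _mm256_setzero_si256();
    // In-lane widening: each int16 lands zero-extended in a 32-bit slot.
    __m256i a_lo = _mm256_unpacklo_epi16(a, zero);
    __m256i b_lo = _mm256_unpacklo_epi16(b, zero);
    __m256i a_hi = _mm256_unpackhi_epi16(a, zero);
    __m256i b_hi = _mm256_unpackhi_epi16(b, zero);
    // Per-slot variable shift; the counts were zero-extended the same way.
    __m256i c_lo = _mm256_sllv_epi32(a_lo, b_lo);
    __m256i c_hi = _mm256_sllv_epi32(a_hi, b_hi);
    // Clear everything above bit 15 so the unsigned-saturating pack is
    // exact, then repack to int16, which is also an in-lane operation.
    const __m256i mask = _mm256_set1_epi32(0xffff);
    return _mm256_packus_epi32(_mm256_and_si256(c_lo, mask),
                               _mm256_and_si256(c_hi, mask));
}
```

Because unpack and pack both operate within each 128-bit lane, the elements come back in their original order with no cross-lane permute needed.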
Force-pushed from 4f38c14 to 9abf2e5
@peterbell10, @sanchitintel: the PR is updated with the suggested change, so that lane crossings are avoided. Timings are updated as well; "mybranch-new" are the latest results, with shifts for 8-bit and 16-bit operands visibly faster.
Force-pushed from 9abf2e5 to 76efc8b
The new commit 76efc8b was an attempt at refactoring to make it easier to implement right shift. Apparently it needs more work, so please ignore it for now.
Force-pushed from 76efc8b to d02da65
(Structured so that it will help with vectorizing the right shift too.)
Force-pushed from d02da65 to 2650ae9
The PR is ready for review and eventual merge. As mentioned above, the shifting implementation for datatypes with bit widths not directly supported by the shift intrinsics has been refactored, so it's now easy to add a right shift implementation on top of this PR.
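A hypothetical illustration (not the PR's actual refactoring) of why such a structure generalizes: reusing the in-lane widening idea from the sketch above, left shift and logical right shift for 16-bit elements differ only in the 32-bit shift intrinsic applied to the zero-extended slots, so the direction can be passed in as a functor:

```cpp
#include <immintrin.h>

// Shared scaffold: widen in-lane, apply the supplied 32-bit shift,
// mask to 16 bits, repack in-lane. Zero-extension is assumed here,
// which is correct for left and logical right shifts.
template <typename ShiftOp>
inline __m256i shift_epi16_via_epi32(__m256i a, __m256i b, ShiftOp op) {
    const __m256i zero = _mm256_setzero_si256();
    const __m256i mask = _mm256_set1_epi32(0xffff);
    __m256i lo = _mm256_and_si256(
        op(_mm256_unpacklo_epi16(a, zero), _mm256_unpacklo_epi16(b, zero)), mask);
    __m256i hi = _mm256_and_si256(
        op(_mm256_unpackhi_epi16(a, zero), _mm256_unpackhi_epi16(b, zero)), mask);
    return _mm256_packus_epi32(lo, hi);
}

// Left shift plugs in _mm256_sllv_epi32; a logical right shift would
// plug in _mm256_srlv_epi32 with the same scaffold.
inline __m256i lshift_epi16(__m256i a, __m256i b) {
    return shift_epi16_via_epi32(a, b,
        [](__m256i x, __m256i s) { return _mm256_sllv_epi32(x, s); });
}
```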
@peterbell10's review is good enough, go ahead and merge.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / macos-12-py3-x86-64-lite-interpreter / build.
@pytorchbot merge -f "flaky ci only"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).
This PR adds a vectorized implementation of the CPU version of the left shift operator.
All of the tests run by `pytest test/test_ops.py -vk left_shift` pass.
Here are some additional details:
<details>
<summary>
Benchmarking script (written by Philip, with small tweaks by Mario) comparing left shifts with multiplications - on par now
</summary>
```python
import torch
from torch import Tensor
from torch.utils.benchmark import Timer, Compare
from itertools import product
from functools import partial
# These functions exist because torch.jit.script does not support `torch.iinfo`
def _num_value_bits(dtype):
if dtype == torch.uint8:
return 8
else: # torch.int32
return 31
def _max_value(dtype):
if dtype == torch.uint8:
return 255
else: # torch.int32
return 2147483647
def bitshift(image, dtype):
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
def mul(image, dtype):
input_max = float(_max_value(image.dtype))
output_max = float(_max_value(dtype))
factor = int((output_max + 1) // (input_max + 1))
image = image.to(dtype)
return image * factor
size = 256
image = torch.randint(0, 256, (3, size, size), dtype=torch.uint8)
dtype = torch.int32
def gen_inputs():
devices = ("cpu",)
fns = (mul, bitshift)
threads = (1,)
for device, fn, threads in product(devices, fns, threads):
yield f"Bitshift {device} {image.dtype}", str(tuple(image.shape)), threads, fn, image, dtype
def benchmark(label, sub_label, threads, f, *args, **kwargs):
return Timer("f(*args, **kwargs)",
globals=locals(),
label=label,
description=f.__name__,
sub_label=sub_label,
num_threads=threads).blocked_autorange()
results = []
for args in gen_inputs():
results.append(benchmark(*args))
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
```
</details>
<details>
<summary>
Test script exercising a large number of combinations of left shift operands that I've used for further testing (validates results by comparing with results generated by NumPy)
</summary>
```python
import numpy as np
import torch
# Testing shifting of non-negative numbers only, but will test all
# possible RHS shift values for the given type. For int8 and int16,
# we'll test shifting all of the non-negative values representable by
# the type. For the rest of the data types, we'll test shifting some
# random numbers in the corresponding range.
def _create_inputs(dtype):
info = torch.iinfo(dtype)
if dtype == torch.int8 or dtype == torch.int16:
ntests = info.max + 1
x = torch.arange(info.max + 1, dtype=dtype, device="cpu", requires_grad=False)
else:
ntests = 100000
x = torch.randint(info.max + 1 if dtype != torch.int64 else info.max, (ntests,), dtype=dtype, device="cpu", requires_grad=False)
y = torch.tensor(range(info.bits), dtype=dtype, device="cpu", requires_grad=False)
xy = torch.cartesian_prod(x, y)
return (xy[:, 0], xy[:, 1])
torch.manual_seed(0)
# Perform testing for each datatype supported, and compare results
# with ones generated by numpy.
for dtype in (torch.int8, torch.int16, torch.int32, torch.int64):
(x, y) = _create_inputs(dtype)
z = x << y
xnp = x.numpy()
ynp = y.numpy()
znp = z.numpy()
assert((znp == (xnp << ynp)).all())
```
</details>
<details>
<summary>
Benchmarking script running the left shift operator on tensors of different lengths (and with varying numbers of bits to shift)
</summary>
```python
import torch
import pickle
import itertools
from torch.utils.benchmark import Timer, Compare
torch.manual_seed(0)
# Edit this part if needed.
lengths = [1024, 4096, 16384, 65536]
rhss = [1, 2, 7, 8, 15, 16, 31, 32, 63, 64]
benchmark_name = "lshift"
label = ""
dtypes = [torch.int8, torch.int16, torch.int32, torch.int64]
results = []
# Create an argument pair for testing. Arguments are tensors of the
# given datatype and length; the LHS for each shift operation is a
# random number, and the RHS is a given value that is the same for all
# of them.
def _make_args(dtype, length, rhs):
info = torch.iinfo(dtype)
return (torch.randint(info.max, (length,), dtype=dtype, device="cpu", requires_grad=False),
rhs * torch.ones((length,), dtype=dtype, device="cpu", requires_grad=False))
# Run the shift operation on vectors of the given lengths and for the
# given numbers of bits to be shifted, and record the timings.
for dtype, length, rhs in itertools.product(dtypes, lengths, rhss):
x, y = _make_args(dtype, length, rhs)
timer = Timer("x << y",
globals=globals(),
label=benchmark_name,
description=label,
sub_label=f"dtype={dtype},length={length}",
num_threads=1)
results.append(timer.blocked_autorange())
# Gather results.
compare = Compare(results)
compare.trim_significant_figures()
compare.print()
# Save results.
with open("{}.pickle".format(label), "wb") as f:
pickle.dump(results, f)
```
</details>
<details>
<summary>
Results of running the above benchmarking script - results manually merged from runs of viable/strict (labeled "master" in the table below) and my branch (labeled "mybranch" in the table below)
</summary>
```
[------------------- lshift -------------------------------]
| master | mybranch
1 threads: ------------------------------------------------
dtype=torch.int8,length=1024 | 3 | 3
dtype=torch.int8,length=4096 | 5 | 3
dtype=torch.int8,length=16384 | 14 | 5
dtype=torch.int8,length=65536 | 51 | 15
dtype=torch.int16,length=1024 | 3 | 3
dtype=torch.int16,length=4096 | 4 | 3
dtype=torch.int16,length=16384 | 11 | 5
dtype=torch.int16,length=65536 | 39 | 13
dtype=torch.int32,length=1024 | 3 | 2
dtype=torch.int32,length=4096 | 4 | 3
dtype=torch.int32,length=16384 | 10 | 4
dtype=torch.int32,length=65536 | 35 | 12
dtype=torch.int64,length=1024 | 3 | 3
dtype=torch.int64,length=4096 | 4 | 3
dtype=torch.int64,length=16384 | 11 | 6
dtype=torch.int64,length=65536 | 36 | 20
Times are in microseconds (us).
```
</details>
All of the testing/benchmarking was conducted on qgpu3, which supports AVX2 only. For basic validation of the AVX-512 version of the left shift implementation for 8-bit operands (the only one that is non-trivial in the AVX-512 case), [Compiler Explorer](https://godbolt.org/) was used, with GCC trunk and the `-mavx512f -mavx512bw` flags added. Here are further details:
<details>
<summary>
C program used for basic validation of AVX-512 vectorized version for 8-bit operands
</summary>
```c
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <immintrin.h>
static void print_m512i_int8(const __m512i* x)
{
int8_t val[64];
memcpy(val, x, sizeof(val));
for (int i = 0; i < 64; ++i) {
if (i > 0)
printf(", ");
printf("%d", (int)val[i]);
}
printf("\n");
}
int main()
{
__m512i a = _mm512_set_epi8(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1);
__m512i b = _mm512_set_epi8(7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 6, 6,
5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4,
3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2,
1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
0);
// ------- Copied code from vec512_int.h
// Mask used to set upper 8 bits of each 16-bit value to 0, and keep
// lower 8 bits.
__m512i mask = _mm512_set_epi16(0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff);
// Convert 8-bit operands from lower lanes to 16-bit values, and
// perform vectorized shift. Make sure that upper 8 bits of 16-bit
// results are all 0.
__m256i a_lo_8 = _mm512_extracti64x4_epi64(a, 0);
__m256i b_lo_8 = _mm512_extracti64x4_epi64(b, 0);
__m512i a_lo_16 = _mm512_cvtepi8_epi16(a_lo_8);
__m512i b_lo_16 = _mm512_cvtepi8_epi16(b_lo_8);
__m512i c_lo_16 = _mm512_and_si512(_mm512_sllv_epi16(a_lo_16, b_lo_16), mask);
// Convert 8-bit operands from upper lanes to 16-bit values, and
// perform vectorized shift. Make sure that upper 8 bits of 16-bit
// results are all 0.
__m256i a_hi_8 = _mm512_extracti64x4_epi64(a, 1);
__m256i b_hi_8 = _mm512_extracti64x4_epi64(b, 1);
__m512i a_hi_16 = _mm512_cvtepi8_epi16(a_hi_8);
__m512i b_hi_16 = _mm512_cvtepi8_epi16(b_hi_8);
__m512i c_hi_16 = _mm512_and_si512(_mm512_sllv_epi16(a_hi_16, b_hi_16), mask);
// Cast 16-bit results back into 8-bit values and merge them
// together (using unsigned saturation with higher 8 bits set to 0
// above ensures that results are correct). Values are merged per
// lanes, so this is not yet the final result.
__m512i c_perm = _mm512_packus_epi16(c_lo_16, c_hi_16);
// Permute values so that final result is produced.
__m512i idx = _mm512_set_epi64(7, 5, 3, 1, 6, 4, 2, 0);
__m512i c = _mm512_permutexvar_epi64(idx, c_perm);
// ------- End copied
print_m512i_int8(&c);
// Expected output: 1(x8), 2(x8), 4(x8), 8(x8), 16(x8), 32(x8), 64(x8), -128(x8)
// (1 << 7 == 128, which wraps to -128 when printed as int8_t)
return 0;
}
```
</details>
Pull Request resolved: pytorch#88607
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/peterbell10
cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10