aten/src/ATen/native/cuda/EmbeddingBackwardKernel.cu (142 changes: 72 additions & 70 deletions)

@@ -13,7 +13,7 @@

#include <thrust/execution_policy.h>
#include <thrust/unique.h>
-#include <thrust/device_vector.h>


namespace at {
namespace native {
@@ -82,7 +82,8 @@ __global__ void compute_grad_weight_bags(
int64_t *offset2bag, int64_t *count, ptrdiff_t numel,
int64_t stride, int mode_mean, const int64_t *bag_size,
scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
-    int64_t* segment_offsets, int64_t num_of_segments, scalar_t *grad_weight_per_segment,
+    int64_t* segment_offsets, int64_t num_of_segments,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t stride_warped) {

const int gid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -126,7 +127,7 @@ __global__ void compute_grad_weight(
int64_t stride,
int64_t* segment_offsets,
int64_t num_of_segments,
-    scalar_t *grad_weight_per_segment,
+    acc_type<scalar_t, true> *grad_weight_per_segment,
int padding_idx,
const int64_t stride_warped) {

@@ -142,9 +143,6 @@
}
const int idx_begin = segment_offsets[id];
const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
-  if (idx_begin == padding_idx) {
-    return;
-  }

accscalar_t weight = 0;
for (int idx=idx_begin; idx < idx_end; ++idx) {
@@ -161,7 +159,8 @@
template <typename scalar_t>
__global__ void sum_and_scatter(
int64_t *input, scalar_t *gradWeight, int64_t stride,
-    int64_t* segment_offsets, int64_t num_of_segments, const scalar_t *grad_weight_per_segment,
+    int64_t* segment_offsets, int64_t num_of_segments,
+    const acc_type<scalar_t, true> *grad_weight_per_segment,
const int64_t *segment_sizes_offsets, int64_t num_of_partial_segments,
const int64_t stride_warped) {

@@ -212,7 +211,7 @@ Tensor embedding_backward_cuda_kernel(
// spawn a warp per index. In this context, a segment is a number of rows that should
// be summarized.
// Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> segment_offsets(numel);
+  auto segment_offsets = at::empty({numel}, orig_indices.options());
int64_t num_of_segments;
{
auto sorted_indices_dev = thrust::device_ptr<int64_t>(sorted_indices.data<int64_t>());
@@ -224,18 +223,18 @@
sorted_indices_dev + numel,
thrust::make_counting_iterator(0),
dummy_dev,
-        thrust::raw_pointer_cast(segment_offsets.data()));
+        thrust::device_ptr<int64_t>(segment_offsets.data<int64_t>()));
num_of_segments = thrust::get<0>(ends) - dummy_dev;
}
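Note: the `unique_by_key_copy` above records, for each distinct value in `sorted_indices`, the position where its run begins; `num_of_segments` is the number of distinct indices. A minimal host-side sketch of the same idea (illustration only, not part of this patch):

    #include <thrust/unique.h>
    #include <thrust/iterator/counting_iterator.h>
    #include <thrust/execution_policy.h>
    #include <cstdio>

    int main() {
      // e.g. sorted embedding indices [0, 0, 1, 3, 3, 3]
      long long keys[6] = {0, 0, 1, 3, 3, 3};
      long long unique_keys[6];        // receives [0, 1, 3]
      long long segment_offsets[6];    // receives [0, 2, 3]
      auto ends = thrust::unique_by_key_copy(
          thrust::host, keys, keys + 6,
          thrust::make_counting_iterator(0LL),
          unique_keys, segment_offsets);
      long long num_of_segments = ends.first - unique_keys;  // 3
      std::printf("%lld segments\n", num_of_segments);
      return 0;
    }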

// We split the segments up into sizes of `NROWS_PER_THREAD`
// Compute the number partial-segments per segment (some partial-segments
// may not be the full `NROWS_PER_THREAD` number of rows)
-  thrust::device_vector<int64_t> partials_per_segment(num_of_segments);
+  auto partials_per_segment = at::empty({num_of_segments}, orig_indices.options());
{
krn_partials_per_segment<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partials_per_segment.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
num_of_segments,
numel);
}
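Note: `krn_partials_per_segment` is defined earlier in this file and is not shown in this diff; it fills `partials_per_segment[id]` with the number of `NROWS_PER_THREAD`-sized chunks segment `id` splits into. A hedged CUDA sketch of that computation (assuming a plain ceil-div; the kernel and parameter names here are reconstructions, not the file's code):

    #include <cstdint>

    // Sketch: one thread per segment computes how many partial segments it needs.
    __global__ void partials_per_segment_sketch(
        int64_t *out, const int64_t *segment_offsets,
        int64_t num_of_segments, int64_t numel, int64_t rows_per_thread) {
      const int id = blockIdx.x * blockDim.x + threadIdx.x;
      if (id < num_of_segments) {
        const int64_t begin = segment_offsets[id];
        const int64_t end =
            (id == num_of_segments - 1) ? numel : segment_offsets[id + 1];
        // A segment of 7 rows with rows_per_thread = 4 yields 2 partial segments.
        out[id] = (end - begin + rows_per_thread - 1) / rows_per_thread;
      }
    }
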
@@ -244,82 +243,85 @@
// of each partial-segment in `sorted_indices`, we need to compute the
// start position of each _segment_ in `partial_segment_offset`.
// Unit: index in `partial_segment_offset`
-  thrust::device_vector<int64_t> partials_per_segment_offset(num_of_segments);
+  auto partials_per_segment_offset = at::empty({num_of_segments}, orig_indices.options());
thrust::exclusive_scan(
policy,
-      partials_per_segment.begin(),
-      partials_per_segment.end(),
-      partials_per_segment_offset.begin());
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()),
+      thrust::device_ptr<int64_t>(partials_per_segment.data<int64_t>()+num_of_segments),
+      thrust::device_ptr<int64_t>(partials_per_segment_offset.data<int64_t>()));

// The total number of partial-segments is the sum of `partials_per_segment_offset`
-  const int num_of_partial_segments = partials_per_segment[num_of_segments-1] +
-                                      partials_per_segment_offset[num_of_segments-1];
+  const int num_of_partial_segments = partials_per_segment[num_of_segments-1].item<int64_t>() +
+                                      partials_per_segment_offset[num_of_segments-1].item<int64_t>();

// Now we can compute the start position of each partial-segment
// Unit: index in `sorted_indices` and `orig_indices`
-  thrust::device_vector<int64_t> partial_segment_offset(num_of_partial_segments);
+  auto partial_segment_offset = at::empty({num_of_partial_segments}, orig_indices.options());
{
krn_partial_segment_offset<<<ceil_div(num_of_segments, 32), 32, 0, stream>>> (
-        thrust::raw_pointer_cast(partial_segment_offset.data()),
-        thrust::raw_pointer_cast(partials_per_segment.data()),
-        thrust::raw_pointer_cast(partials_per_segment_offset.data()),
-        thrust::raw_pointer_cast(segment_offsets.data()),
+        partial_segment_offset.data<int64_t>(),
+        partials_per_segment.data<int64_t>(),
+        partials_per_segment_offset.data<int64_t>(),
+        segment_offsets.data<int64_t>(),
num_of_segments);
}
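
Worked example of the bookkeeping above (illustrative numbers): with `partials_per_segment = [2, 1, 3]`, the exclusive scan gives `partials_per_segment_offset = [0, 2, 3]`, so `num_of_partial_segments = 3 + 3 = 6`, and `krn_partial_segment_offset` then expands each segment start into the starts of its chunks. The same arithmetic on the host (illustration only, not code from this file):

    #include <numeric>
    #include <vector>
    #include <cassert>

    int main() {
      std::vector<long long> partials_per_segment{2, 1, 3};
      std::vector<long long> offset(partials_per_segment.size());
      // exclusive prefix sum: offset = {0, 2, 3}
      std::exclusive_scan(partials_per_segment.begin(), partials_per_segment.end(),
                          offset.begin(), 0LL);
      // total = last per-segment count + last scan offset
      long long num_of_partial_segments =
          partials_per_segment.back() + offset.back();
      assert(num_of_partial_segments == 6);
      return 0;
    }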

-  auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, grad.options());
const int stride_warped = ceil_div(stride, WARP_SIZE)*WARP_SIZE;
const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
const int grid = ceil_div(num_of_partial_segments*stride_warped, block);

-  // Compute the sum of each partial-segment and handle bags
-  if (offset2bag.defined()) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          offset2bag.data<int64_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
-          mode_mean, bag_size.data<int64_t>(),
-          per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
-          per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments, grad_weight_per_segment.data<scalar_t>(),
-          stride_warped);
-      });
-  } else {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
-        compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
-          orig_indices.data<int64_t>(),
-          grad.data<scalar_t>(),
-          count.defined() ? count.data<int64_t>() : nullptr,
-          numel, stride,
-          thrust::raw_pointer_cast(partial_segment_offset.data()),
-          num_of_partial_segments,
-          grad_weight_per_segment.data<scalar_t>(),
-          padding_idx,
-          stride_warped);
-      });
-  }
-  THCudaCheck(cudaGetLastError());

-  // Finally, we sum all the partial-sums and scatter them
-  // into `grad_weight`.
-  const int grid2 = ceil_div(num_of_segments*stride_warped, block);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad.scalar_type(), "embedding_bag_backward_cuda_sum_and_scatter", [&] {
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data<int64_t>(),
grad_weight.data<scalar_t>(),
stride,
thrust::raw_pointer_cast(segment_offsets.data()),
num_of_segments, grad_weight_per_segment.data<scalar_t>(),
thrust::raw_pointer_cast(partials_per_segment_offset.data()),
num_of_partial_segments, stride_warped);
grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
// For numerical stability, the dtype of `grad_weight_per_segment`
// should match `acc_type`
using partial_weight_t = acc_type<scalar_t, true>;
TensorOptions op;
if(grad.dtype() == at::kHalf) {
op = grad.options().dtype(at::kFloat);
} else {
op = grad.options();
}
auto grad_weight_per_segment = at::empty({num_of_partial_segments, stride}, op);
// Compute the sum of each partial-segment and handle bags
if (offset2bag.defined()) {
compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data<int64_t>(),
grad.data<scalar_t>(),
offset2bag.data<int64_t>(),
count.defined() ? count.data<int64_t>() : nullptr, numel, stride,
mode_mean, bag_size.data<int64_t>(),
per_sample_weights.defined() ? per_sample_weights.data<scalar_t>() : NULL,
per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
partial_segment_offset.data<int64_t>(),
num_of_partial_segments, grad_weight_per_segment.data<partial_weight_t>(),
stride_warped);
} else {
compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
orig_indices.data<int64_t>(),
grad.data<scalar_t>(),
count.defined() ? count.data<int64_t>() : nullptr,
numel, stride,
partial_segment_offset.data<int64_t>(),
num_of_partial_segments,
grad_weight_per_segment.data<partial_weight_t>(),
padding_idx,
stride_warped);
}
THCudaCheck(cudaGetLastError());

// Finally, we sum all the partial-sums and scatter them
// into `grad_weight`.
const int grid2 = ceil_div(num_of_segments*stride_warped, block);
sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
sorted_indices.data<int64_t>(),
grad_weight.data<scalar_t>(),
stride,
segment_offsets.data<int64_t>(),
num_of_segments, grad_weight_per_segment.data<partial_weight_t>(),
partials_per_segment_offset.data<int64_t>(),
num_of_partial_segments, stride_warped);
THCudaCheck(cudaGetLastError());
});
THCudaCheck(cudaGetLastError());
return grad_weight;
}
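Note on the numerics this change fixes: with `scalar_t = at::Half`, `acc_type<at::Half, true>` is `float`, so the per-partial-segment sums now accumulate in fp32 (the `kHalf` branch also allocates `grad_weight_per_segment` as `kFloat`), and only the final scatter writes back in the gradient's dtype. A hedged standalone sketch of why this matters (illustration, not code from this file):

    #include <cuda_fp16.h>

    // Accumulating many fp16 values directly loses low-order bits once the
    // running sum grows; accumulating in fp32 and rounding once at the end
    // keeps the error at a single rounding step.
    __global__ void sum_in_float(const __half *vals, int n, __half *out) {
      float acc = 0.0f;  // plays the role of acc_type<at::Half, true>
      for (int i = 0; i < n; ++i) {
        acc += __half2float(vals[i]);
      }
      *out = __float2half(acc);  // round back to half once
    }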
