fix: multi-chunk batched bucket sums folding error #262
Pull Request Overview
This PR fixes incorrect bucket sums when the number of partial-bucket chunks exceeds the GPU kernel’s max chunk size by introducing a folding kernel and updating the accumulation logic and tests.
- Adds `segmented_left_fold_partial_bucket_sums` kernel and its unit tests to fold multi-chunk partial sums.
- Updates `accumulate_buckets_impl` to use the new fold kernel instead of `combine_partial_bucket_sums`.
- Extends multiexponentiation tests to verify results when input sizes exceed the max chunk threshold.
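To make the max chunk threshold concrete, here is a minimal C++ sketch using the numbers quoted in the PR description (`batch_size = 1<<3`, `element_length = 1<<17`, max chunk size `2^20`). It assumes that the quantity compared against the limit is `batch_size * element_length` and that the chunk count is a ceiling division; `num_chunks_for` is a hypothetical helper for illustration, not a function from the library.

```cpp
#include <cstddef>

// Hypothetical helper (not part of the library): number of chunks needed to
// cover `total` items when one chunk holds at most `max_chunk_size` items.
constexpr std::size_t num_chunks_for(std::size_t total, std::size_t max_chunk_size) {
  return (total + max_chunk_size - 1) / max_chunk_size;  // ceiling division
}

// Limit and batch size as quoted in the PR description.
constexpr std::size_t max_chunk_size = std::size_t{1} << 20;
constexpr std::size_t batch_size = std::size_t{1} << 3;

// Assumption: the total compared against the limit is batch_size * element_length.
static_assert(num_chunks_for(batch_size * (std::size_t{1} << 17), max_chunk_size) == 1,
              "exactly at the limit: a single chunk");
static_assert(num_chunks_for(batch_size * ((std::size_t{1} << 17) + 1), max_chunk_size) == 2,
              "one element longer: a second chunk is needed");
```

Under these assumptions, the first case stays within one chunk, while the second crosses into the multi-chunk path that this PR fixes.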
Reviewed Changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| sxt/multiexp/bucket_method/fold_kernel.h | Defines the new segmented left‐fold GPU kernel |
| sxt/multiexp/bucket_method/fold_kernel.t.cc | Adds tests for single‐ and multi‐bucket folding |
| sxt/multiexp/bucket_method/accumulation.h | Replaces old combination kernel launch with fold kernel launch |
| sxt/multiexp/bucket_method/multiexponentiation.t.cc | Adds tests for multiexponentiation with element counts over chunk size |
| sxt/multiexp/bucket_method/BUILD | Registers the fold_kernel component |
Comments suppressed due to low confidence (2)
sxt/multiexp/bucket_method/fold_kernel.h:45
- [nitpick] The variables `bucket_group_size` and `num_bucket_groups` map to gridDim.x and blockDim.x but their names are confusing. Consider renaming them to `num_buckets` and `num_chunks` (and `chunk_index` for threadIdx.x) to make the folding logic clearer.
auto bucket_group_size = gridDim.x;
sxt/multiexp/bucket_method/fold_kernel.t.cc:60
- The unit tests cover single‐output scenarios only. Add a SECTION that launches the kernel with `gridDim.y > 1` to verify folding across multiple outputs.
// end of tests
Sorry to ask for this, just not as familiar w/ this code and its terms. I think it'd help me understand a little if I had some mental model for how a multi-exponentiation maps to "buckets" and "bucket groups" and "elements" and "chunks" and "folds", etc. Maybe we can hop on a call or maybe text + a diagram would be sufficient
@tlovell-sxt a diagram is a good idea. I'll create one and set up a call to walk through it.
… on multiple chunks
24ff343 to a5a5ee9
🎉 This PR is included in version 1.115.1 🎉
The release is available on GitHub release.
Your semantic-release bot 📦🚀
Rationale for this change
Batch commitments whose elements are represented in 32 bytes, that have 256-384 elements in the sequence, and that are all the same length will attempt bucket method multiexponentiation. The bucket method has a max chunk size of `2^20`. In cases where the batch size and element length go above the max chunk size, the commitments are incorrect. For example, `batch_size = 1<<3` and `element_length = 1<<17` will return the expected result, while `element_length = (1<<17) + 1` will return an unexpected result. You can see test cases that reproduce the error in commit 6d6c2e4.

The issue is the `mtxbk::accumulate_buckets_impl` method. In cases where the number of chunks, `num_chunks`, is greater than 1, the partial bucket sums get split by the max chunk size number of elements. After all the buckets are accumulated, the call to the `combine_partial_bucket_sums` kernel does not handle the data offset and stride calculations needed to account for the multi-chunk partial bucket sums array.

To solve this issue, `fold_kernel` is added. The purpose of `fold_kernel` is to take the chunked partial sums and fold them into a single bucket sums array. The `fold_kernel` class performs a segmented left fold of the partial sums:

`out[index] = sum(partial_bucket_sums[index + i * out_size]) for i in [0, num_of_folds)`

where `num_of_folds = partial_bucket_sums.size() / out.size()`.

The tests will have to be updated when the max chunk size is increased or decreased from its `1<<20` limit.

What changes are included in this PR?
- `fold_kernel` with a `segmented_left_fold_partial_bucket_sums` method is added to the `multiexp` package with tests.
- `accumulation` with multi-chunk cases will now use `segmented_left_fold_partial_bucket_sums` when all the partial buckets are accumulated.
- Tests are added to `multiexponentiation` which reproduce the error state. They add less than 2 seconds to the overall test time.

Are these changes tested?
Yes. Also tested by replacing the `libblitzar-linux-x86_64.so` in the `blitzar-rs` and `nova` projects and confirming the failing tests pass.
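
The segmented left fold described in the rationale can be illustrated with a host-side reference in plain C++. This is a sketch for building intuition, not the GPU kernel from this PR: the function name `segmented_left_fold` and the use of a generic `T` with `operator+` in place of a curve element are assumptions made for the example. It follows the formula given above, where chunk `i` occupies the slice starting at offset `i * out_size`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference for a segmented left fold (illustrative only).
// T stands in for a curve element; any type with operator+ works here.
template <class T>
std::vector<T> segmented_left_fold(const std::vector<T>& partial_bucket_sums,
                                   std::size_t out_size) {
  // The input must hold a whole number of chunks, and at least one.
  assert(out_size > 0);
  assert(partial_bucket_sums.size() >= out_size);
  assert(partial_bucket_sums.size() % out_size == 0);

  const std::size_t num_of_folds = partial_bucket_sums.size() / out_size;

  // Start from chunk 0, then add chunks 1..num_of_folds-1 from left to right.
  std::vector<T> out(partial_bucket_sums.begin(),
                     partial_bucket_sums.begin() + static_cast<std::ptrdiff_t>(out_size));
  for (std::size_t i = 1; i < num_of_folds; ++i) {
    for (std::size_t index = 0; index < out_size; ++index) {
      out[index] = out[index] + partial_bucket_sums[index + i * out_size];
    }
  }
  return out;
}
```

For example, folding the nine integers `{1, 2, 3, 10, 20, 30, 100, 200, 300}` with `out_size = 3` treats them as three chunks of three partial sums and yields `{111, 222, 333}`.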