⚔️ Optimize truncate_with_protected_tokens to use vectorized operations #3875
Conversation
Replace the list comprehension with `torch.isin` for protected-token checking. This reduces CPU-GPU synchronization from O(seq_len) to O(1) per sequence, achieving a ~9x speedup for typical batch sizes.
Pull Request Overview
This PR optimizes the truncate_with_protected_tokens function in the GRPO trainer by replacing inefficient list comprehensions with vectorized tensor operations. The optimization significantly reduces CPU-GPU synchronization overhead and improves performance.
Key changes:
- Replaces list comprehension with `torch.isin()` for vectorized membership testing
- Eliminates O(seq_len) CPU-GPU synchronization points per sequence
- Achieves ~9x speedup on typical workloads
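The membership test the review refers to can be illustrated with a minimal example (the token ids below are made up for illustration, not taken from the PR):

```python
import torch

# torch.isin tests each element of `ids` for membership in `protected`
# in a single vectorized call, instead of a Python loop that calls
# .item() once per token (one CPU-GPU sync per position).
ids = torch.tensor([101, 42, 7, 102])
protected = torch.tensor([101, 102])  # e.g. special/protected token ids (assumed)
mask = torch.isin(ids, protected)
print(mask.tolist())  # [True, False, False, True]
```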
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
…ns (huggingface#3875) Co-authored-by: Kashif Rasul <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
What does this PR do?
This PR optimizes the `truncate_with_protected_tokens` function in the GRPO trainer by replacing list comprehensions with vectorized tensor operations. The optimization reduces CPU-GPU synchronization overhead from O(seq_len) to O(1) per sequence.
Performance improvement
The original implementation used a list comprehension that called `.item()` for every token in the sequence to check whether it was in the protected set. The optimized version uses `torch.isin` for vectorized membership testing.
Benchmark results
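The contrast between the two approaches can be sketched as follows. This is a minimal illustration with assumed function names and token ids, not TRL's actual implementation:

```python
import torch

def protected_mask_loop(ids: torch.Tensor, protected_set: set) -> torch.Tensor:
    # The pattern the PR removes: one .item() call per token, i.e. one
    # CPU-GPU synchronization point per position in the sequence.
    return torch.tensor([t.item() in protected_set for t in ids], dtype=torch.bool)

def protected_mask_vectorized(ids: torch.Tensor, protected: torch.Tensor) -> torch.Tensor:
    # The replacement: a single vectorized membership test.
    return torch.isin(ids, protected)

# Illustrative token ids (assumed, not from the PR):
ids = torch.tensor([5, 1, 7, 2, 9])
assert torch.equal(protected_mask_loop(ids, {1, 2}),
                   protected_mask_vectorized(ids, torch.tensor([1, 2])))
```

Both functions produce the same boolean mask; the vectorized version simply stays on-device instead of round-tripping each token through Python.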
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.