Conversation

@chi2liu (Contributor) commented Aug 9, 2025

What does this PR do?

This PR optimizes the truncate_with_protected_tokens function in the GRPO trainer by replacing a list comprehension with vectorized tensor operations. The optimization reduces CPU-GPU synchronization overhead from O(seq_len) to O(1) per sequence.

Performance improvement

The original implementation used a list comprehension that called .item() on every token in the sequence to check whether it was in the protected set:

is_protected = torch.tensor([x.item() in protected_set for x in ids])

The optimized version uses torch.isin for vectorized membership testing:

is_protected = torch.isin(ids, torch.tensor(list(protected_set), device=ids.device))
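As a standalone sanity check, the two approaches produce identical masks; the tensors and token IDs below are illustrative, not TRL's actual code:

```python
import torch

# Hypothetical inputs: a short sequence of token IDs and a set of
# protected (e.g. special) tokens that must survive truncation.
ids = torch.tensor([101, 7592, 2088, 102, 0, 0])
protected_set = {101, 102}

# Original approach: one .item() call per token. On GPU tensors, each
# .item() forces a device-to-host transfer, i.e. a CPU-GPU sync point.
is_protected_loop = torch.tensor([x.item() in protected_set for x in ids])

# Optimized approach: a single vectorized membership test.
is_protected_vec = torch.isin(
    ids, torch.tensor(list(protected_set), device=ids.device)
)

assert torch.equal(is_protected_loop, is_protected_vec)
```

Note that torch.isin keeps the test tensor on the same device as ids, so no per-element synchronization is needed.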

Benchmark results

  • ~9x speedup on typical workloads (batch_size=32, seq_len=512)
  • Time saved: ~31ms per batch
  • Reduces CPU-GPU synchronization points from O(seq_len) to O(1) per sequence
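A minimal micro-benchmark along these lines can reproduce the effect (run here on CPU; the ~9x figure above was measured on the author's setup and will vary with hardware and device placement):

```python
import time
import torch

# Illustrative workload matching the sizes quoted above.
batch_size, seq_len = 32, 512
protected_set = {0, 1, 2}
batch = torch.randint(0, 32000, (batch_size, seq_len))
protected_tensor = torch.tensor(list(protected_set))

# Original: per-token .item() calls inside a list comprehension.
start = time.perf_counter()
for ids in batch:
    _ = torch.tensor([x.item() in protected_set for x in ids])
loop_time = time.perf_counter() - start

# Optimized: one torch.isin call per sequence.
start = time.perf_counter()
for ids in batch:
    _ = torch.isin(ids, protected_tensor)
vec_time = time.perf_counter() - start

print(f"loop: {loop_time * 1e3:.1f} ms, isin: {vec_time * 1e3:.1f} ms")
```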

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? (Existing tests cover this function)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Replace list comprehension with torch.isin for protected token checking.
This reduces CPU-GPU synchronization from O(seq_len) to O(1) per sequence,
achieving ~9x speedup for typical batch sizes.
@qgallouedec qgallouedec requested a review from Copilot August 9, 2025 18:43
Copilot AI (Contributor) left a comment

Pull Request Overview

This PR optimizes the truncate_with_protected_tokens function in the GRPO trainer by replacing inefficient list comprehensions with vectorized tensor operations. The optimization significantly reduces CPU-GPU synchronization overhead and improves performance.

Key changes:

  • Replaces list comprehension with torch.isin() for vectorized membership testing
  • Eliminates O(seq_len) CPU-GPU synchronization points per sequence
  • Achieves ~9x speedup on typical workloads

@chi2liu chi2liu requested a review from kashif August 11, 2025 09:14
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@qgallouedec qgallouedec changed the title Optimize truncate_with_protected_tokens to use vectorized operations ⚔️ Optimize truncate_with_protected_tokens to use vectorized operations Aug 17, 2025
@qgallouedec qgallouedec merged commit a6f802f into huggingface:main Aug 17, 2025
LuisVasquezBSC pushed a commit to langtech-bsc/trl that referenced this pull request Aug 28, 2025
SamY724 pushed a commit to SamY724/trl that referenced this pull request Sep 6, 2025
4 participants