
perf(feature_extraction_sequence): skip re-splitting already-batched numpy arrays in pad() #46329

Merged
Rocketknight1 merged 1 commit into huggingface:main from Anai-Guo:fix-seqfeat-pad-numpy-noop on Jun 2, 2026

Conversation

Anai-Guo (Contributor) commented Jun 1, 2026

What does this PR do?

Fixes #46328.

SequenceFeatureExtractor.pad() normalizes its inputs with:

for key, value in processed_features.items():
    if isinstance(value[0], (int, float)):
        processed_features[key] = to_numpy(value)
    else:
        processed_features[key] = [to_numpy(v) for v in value]

When value is already a batched numpy array, the else branch rebuilds it as a
Python list of per-example arrays ([to_numpy(v) for v in value]). The iteration and
per-row copy are pure overhead and become very slow for large inputs (the issue reports
several minutes on a 25-minute audio file).
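
As a rough illustration of what the else branch does to an already-batched array, here is a minimal sketch. `to_numpy_stand_in` is a hypothetical, simplified stand-in for the library's `to_numpy` helper (the real helper also handles framework tensors), and the batch shape is made up for the example.

```python
import numpy as np


def to_numpy_stand_in(obj):
    # Hypothetical stand-in for the library's to_numpy helper (assumption):
    # it simply coerces its argument to an ndarray.
    return np.asarray(obj)


# An already-batched input: 4 examples, 16000 samples each (illustrative sizes).
batch = np.zeros((4, 16000), dtype=np.float32)

# batch[0] is an ndarray, not an int or float, so the else branch is taken
# and the batch is rebuilt as a Python list, one row at a time.
takes_else_branch = not isinstance(batch[0], (int, float))
rebuilt = [to_numpy_stand_in(row) for row in batch]
```

The loop runs once per element along the first axis, so the cost grows with the length of that axis even though no information is added.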

It is also unnecessary: the downstream _truncate/_pad logic only ever indexes
value[i], which behaves identically for a list of arrays and for a batched ndarray
(same per-example shapes, same len() for the batch dimension). So an already-batched
array can be used as-is.
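
The equivalence claimed above can be checked with plain numpy. This is only a sketch of the indexing behavior; it does not call the library's _truncate/_pad code, and the array contents are arbitrary.

```python
import numpy as np

# A small batched array and the list-of-rows form the old code would build.
batch = np.arange(12, dtype=np.float32).reshape(3, 4)
as_list = [row for row in batch]

# Both containers report the same batch size.
same_len = len(batch) == len(as_list)

# Indexing value[i] yields rows with the same shape and contents either way.
rows_match = all(
    batch[i].shape == as_list[i].shape and np.array_equal(batch[i], as_list[i])
    for i in range(len(batch))
)
```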

This PR skips the conversion when value is already an np.ndarray. The common
list-of-arrays / list-of-tensors path is unchanged, so existing behavior is preserved.
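
The merged diff is not shown in this conversation, so the following is a sketch of how such a guard could look, not the actual change. `normalize_features` and `to_numpy_stand_in` are hypothetical names introduced for illustration.

```python
import numpy as np


def to_numpy_stand_in(obj):
    # Hypothetical stand-in for the library's to_numpy helper (assumption).
    return np.asarray(obj)


def normalize_features(processed_features):
    # Sketch of the normalization loop with an early skip for ndarrays.
    for key, value in processed_features.items():
        if isinstance(value, np.ndarray):
            # Already a (batched) numpy array: leave it untouched.
            continue
        if isinstance(value[0], (int, float)):
            processed_features[key] = to_numpy_stand_in(value)
        else:
            processed_features[key] = [to_numpy_stand_in(v) for v in value]
    return processed_features


batch = np.zeros((2, 8), dtype=np.float32)
features = {
    "batched": batch,                 # skipped by the guard
    "ragged": [[1.0, 2.0], [3.0]],    # list of lists: converted per example
    "flat": [1.0, 2.0],               # list of scalars: converted as a whole
}
out = normalize_features(features)
```

With the guard in place the batched entry is returned as the same object, while list inputs follow the original two branches.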

Before submitting

Generated with Claude Code (AI-assisted, human reviewed).

…numpy arrays in pad()

When pad() receives a value that is already a numpy array, the existing code
rebuilds it as a Python list of per-element arrays via [to_numpy(v) for v in value].
For large inputs (e.g. long audio) this iteration and per-row copy is very slow and
serves no purpose: the downstream truncate/pad logic indexes value[i] identically for
both a list of arrays and a batched ndarray.

Skip the conversion when value is already an ndarray. The common list-of-arrays path
is unchanged.

Fixes huggingface#46328
Anai-Guo force-pushed the fix-seqfeat-pad-numpy-noop branch from c8c51a2 to 30a9526 on June 1, 2026 16:13
Rocketknight1 (Member) left a comment

This seems like a safe optimization that shouldn't have side-effects, so LGTM!

Rocketknight1 enabled auto-merge June 2, 2026 12:00
Rocketknight1 added this pull request to the merge queue Jun 2, 2026
HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

github-merge-queue bot removed this pull request from the merge queue due to failed status checks Jun 2, 2026
Rocketknight1 added this pull request to the merge queue Jun 2, 2026
Merged via the queue into huggingface:main with commit 8d410c6 Jun 2, 2026
30 checks passed
Development

Successfully merging this pull request may close these issues.

SequenceFeatureExtractor.pad wasting time converting numpy array to list of numpy arrays

3 participants