Inconsistent handling of empty inputs when truncating using LongestFirst #381
Description
The behavior of truncation when using the LongestFirst strategy is inconsistent with respect to empty inputs in pairs. In particular, if one element of the pair is empty and the other is longer than the maximum (i.e. truncation needs to be performed), then a TruncationError::MaxLengthTooLow error is raised. However, if both inputs are empty, or one input is empty and no truncation occurs (because the other's length is less than the maximum), no error is raised.
I initially created this issue in the transformers repo, huggingface/transformers#6669, however after looking into it a bit more I believe this issue seems isolated to tokenizers.
Example
```python
from transformers import BertTokenizer, BertTokenizerFast

tokz = BertTokenizer.from_pretrained('bert-base-uncased')
tokz_fast = BertTokenizerFast.from_pretrained('bert-base-uncased')

empty = ''
short = 'the ' * 509
long = 'the ' * 510

# Case 1: no truncation, no error
tokz(empty, empty, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
tokz_fast(empty, empty, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)

# Case 2: no truncation, no error
tokz(empty, short, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
tokz_fast(empty, short, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)

# Case 3: truncation, no error
tokz(long, long, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
tokz_fast(long, long, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)

# Case 4: truncation, TruncationError from BertTokenizerFast only
tokz(empty, long, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
tokz_fast(empty, long, padding=True, truncation='longest_first', return_tensors='pt', max_length=512)
```

Diagnosis
This issue occurs because the checks in the truncate_encodings function are inconsistent with each other. In particular, the check at truncation.rs:L82 only verifies that the total length does not exceed the maximum, which lets empty inputs pass through untouched. However, the check at truncation.rs:L100 explicitly disallows either input from having length zero after truncation.
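To make the inconsistency concrete, here is a simplified Python model of these two checks. The name and control flow are my approximation of the Rust code, not the actual implementation:

```python
def truncate_encodings_model(len_a, len_b, max_length):
    # Check 1 (cf. truncation.rs:L82): if the total already fits,
    # return immediately -- empty inputs pass through here untouched.
    if len_a + len_b <= max_length:
        return len_a, len_b
    # LongestFirst: remove one token at a time from the longer sequence.
    for _ in range(len_a + len_b - max_length):
        if len_a > len_b:
            len_a -= 1
        else:
            len_b -= 1
    # Check 2 (cf. truncation.rs:L100): an empty sequence after
    # truncation is rejected, even though an empty *input* passed
    # Check 1 without complaint.
    if len_a == 0 or len_b == 0:
        raise ValueError('MaxLengthTooLow')
    return len_a, len_b

for case in [(0, 0), (0, 300), (600, 600), (0, 600)]:
    try:
        print(case, '->', truncate_encodings_model(*case, 512))
    except ValueError as e:
        # Only Case 4, (0, 600), reaches this branch.
        print(case, '-> error:', e)
```

Cases 1–3 from the example succeed, while Case 4 raises, matching the observed behavior.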
Proposal
I propose that the truncation strategy be made consistent with itself and with the transformers package. A minimal change would be to eliminate the second check on line 100. However, I would also propose replacing the O(n) iterative algorithm currently used to implement LongestFirst truncation with a constant-time algorithm. Here is a Python sketch of the proposed algorithm.
```python
import math

def truncate_longest(a, b, max_length):
    if max_length <= 0:
        raise ValueError(f'max_length must be greater than 0 (got: {max_length})')
    n1 = len(a)
    n2 = len(b)
    if n1 + n2 <= max_length:
        return a, b
    # Make a the shorter list.
    swap = False
    if n1 > n2:
        swap = True
        a, b = b, a
        n1, n2 = n2, n1
    # Set the length of the longer list to the larger of:
    # - the remaining length after removing the length of the shorter list;
    # - the length of the shorter list.
    n2 = max(n1, max_length - n1)
    if n1 + n2 > max_length:
        # Need to truncate both lists. If max_length is not divisible
        # by 2, let the initially longer list get the extra token.
        half_length = max_length / 2
        n1 = math.floor(half_length)
        n2 = math.ceil(half_length)
    a, b = a[:n1], b[:n2]
    if swap:
        return b, a
    return a, b
```

If this change is acceptable, I'd be happy to create and submit a PR. Is there a CONTRIBUTING guide for tokenizers anywhere? I've never worked with Rust.
edit: grammar