🐛 Bug
Information
Model I am using (Bert, XLNet ...): RoBERTa
Language I am using the model on (English, Chinese ...): English
The problem arises when using:
- the official example scripts: (give details below)
- my own modified scripts: (give details below)
I use RobertaTokenizerFast on pretokenized text, but the problem also arises when I switch to the slow version
The tasks I am working on is:
- an official GLUE/SQUaD task: (give the name)
- my own task or dataset: (give details below)
I am trying to implement a sliding window for RoBERTa
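The chunking itself is plain Python; a minimal sketch of the sliding window used below (window of 126 tokens with stride 100, leaving room for the two special tokens `<s>` and `</s>` in a 128-token model input):

```python
def sliding_windows(tokens, window=126, stride=100):
    """Split a token list into overlapping chunks.

    window=126 leaves room for the <s> and </s> special tokens
    that prepare_for_model / __call__ add, giving 128 per input.
    """
    return [tokens[i:i + window] for i in range(0, len(tokens), stride)]

tokens = [f"tok{i}" for i in range(300)]
chunks = sliding_windows(tokens)
print([len(c) for c in chunks])  # [126, 126, 100]
```

Consecutive windows overlap by 26 tokens, so every token appears in at least one full-context chunk.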
To reproduce
I use the tokenizer.tokenize(text) method to tokenize the whole text (1-3 sentences), then I divide the tokens into chunks and try to use the __call__ method (I also tried encode) with the is_pretokenized=True argument, but this creates additional tokens (roughly 3 times more than it should). I worked around this by using a tokenize -> convert_tokens_to_ids -> prepare_for_model -> pad pipeline, but I believe the batch methods should be faster and more memory efficient.
Steps to reproduce the behavior:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('roberta-base', add_prefix_space=True, use_fast=True)
ex_text = 'long text'
tokens = tokenizer.tokenize(ex_text)
examples = [tokens[i:i+126] for i in range(0, len(tokens), 100)]
print(len(tokenizer(examples, is_pretokenized=True)['input_ids'][0]))  # this prints more than 128
Expected behavior
I would expect to get a result similar to the one I get when I use:
tokens = tokenizer.tokenize(ex_text)
inputs = tokenizer.convert_tokens_to_ids(tokens)
inputs = [inputs[i:i+126] for i in range(0, len(tokens), 100)]
inputs = [tokenizer.prepare_for_model(example) for example in inputs]
inputs = tokenizer.pad(inputs, padding='longest')
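For reference, the final pad step in that pipeline only right-pads each sequence to the batch's longest; a toy stand-in of what it does (assuming RoBERTa's pad token id of 1 — check tokenizer.pad_token_id to be sure):

```python
def pad_longest(batch, pad_id=1):
    """Right-pad every id list to the length of the longest one and
    build the matching attention masks (1 = real token, 0 = pad)."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids))
                      for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

batch = [[0, 100, 200, 2], [0, 100, 2]]
out = pad_longest(batch)
print(out["input_ids"])       # [[0, 100, 200, 2], [0, 100, 2, 1]]
print(out["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 1, 0]]
```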
Am I doing something wrong, or is this unexpected behaviour?
Environment info
- transformers version: 3.0.2
- Platform: macOS
- Python version: 3.8.3
- PyTorch version (GPU?): 1.5.1 (no GPU)
- Tensorflow version (GPU?): NO
- Using GPU in script?: NO
- Using distributed or parallel set-up in script?: NO
EDIT:
I see that when I use __call__, it actually treats Ġ as 2 tokens:
tokenizer(tokenizer.tokenize('How'), is_pretokenized=True)['input_ids']
out: [0, 4236, 21402, 6179, 2], where 4236, 21402 correspond to Ġ
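A stdlib-only sketch of what seems to happen: with is_pretokenized=True, each entry is treated as a raw word, so the subword marker 'Ġ' (U+0120) that tokenize() produced is re-encoded as literal text, and a byte-level BPE turns its two UTF-8 bytes into two extra tokens:

```python
# 'Ġ' is the byte-level BPE word-start marker (U+0120), not a space.
marker = "Ġ"
print(hex(ord(marker)))              # 0x120
print(list(marker.encode("utf-8")))  # [196, 160] -> two byte-level pieces
# So re-encoding 'ĠHow' as raw text yields 3 pieces (2 marker bytes +
# 'How'), matching the observed [4236, 21402, 6179] instead of a single
# id for 'ĠHow'.
```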