How to use XLM-R as retriever correctly? #506

@khalidbhs

Description

I'm trying to use xlm-r-100langs-bert-base-nli-stsb-mean-tokens as a retriever with

retriever = EmbeddingRetriever(document_store=document_store, embedding_model='xlm-r-100langs-bert-base-nli-stsb-mean-tokens', model_format='sentence_transformers')

When I try to embed a text with retriever.embed('test'), it raises this error:

/usr/local/lib/python3.6/dist-packages/transformers/modeling_utils.py in get_extended_attention_mask(self, attention_mask, input_shape, device)
    260             raise ValueError(
    261                 "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
--> 262                     input_shape, attention_mask.shape
    263                 )
    264             )

ValueError: Wrong shape for input_ids (shape torch.Size([4])) or attention_mask (shape torch.Size([4]))
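One possible reading of the torch.Size([4]) in that error (an assumption, not confirmed by the traceback alone): embed may expect a list of texts, so a bare string gets iterated character by character, and 'test' becomes four one-character "texts". The hypothetical embed_texts below is only a stand-in for a batch-embedding API that iterates its argument, not the actual Haystack implementation:

```python
# Hypothetical sketch: how a bare string can silently become a "batch" of
# characters. embed_texts stands in for any batch API that iterates its input.
def embed_texts(texts):
    # Iterates the argument item by item, exactly like a batching loop would.
    return [t for t in texts]

print(embed_texts("test"))    # ['t', 'e', 's', 't']  four 1-char "texts"
print(embed_texts(["test"]))  # ['test']              one text, as intended
```

If this is the cause, wrapping the input as retriever.embed(['test']) would be worth trying.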

I also tried to load the model from the Hugging Face model hub:

retriever = EmbeddingRetriever(document_store=document_store, embedding_model='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens', model_format='transformers')

but it raises this error:

TypeError                                 Traceback (most recent call last)

<ipython-input-34-0b021b13e848> in <module>()
      1 from haystack.retriever.dense import EmbeddingRetriever
----> 2 retriever = EmbeddingRetriever(document_store=document_store, embedding_model='sentence-transformers/xlm-r-100langs-bert-base-nli-stsb-mean-tokens', model_format='transformers')

/usr/local/lib/python3.6/dist-packages/haystack/retriever/dense.py in __init__(self, document_store, embedding_model, use_gpu, model_format, pooling_strategy, emb_extraction_layer)
    300             self.embedding_model = Inferencer.load(
    301                 embedding_model, task_type="embeddings", extraction_strategy=self.pooling_strategy,
--> 302                 extraction_layer=self.emb_extraction_layer, gpu=use_gpu, batch_size=4, max_seq_len=512, num_processes=0
    303             )
    304 

/usr/local/lib/python3.6/dist-packages/farm/infer.py in load(cls, model_name_or_path, batch_size, gpu, task_type, return_class_probs, strict, max_seq_len, doc_stride, extraction_layer, extraction_strategy, s3e_stats, num_processes, disable_tqdm, tokenizer_class, use_fast, tokenizer_args, dummy_ph, benchmarking)
    271                                        tokenizer_class=tokenizer_class,
    272                                        use_fast=use_fast,
--> 273                                        **tokenizer_args,
    274                                        )
    275 

/usr/local/lib/python3.6/dist-packages/farm/modeling/tokenization.py in load(cls, pretrained_model_name_or_path, tokenizer_class, use_fast, **kwargs)
    131                 ret = BertTokenizerFast.from_pretrained(pretrained_model_name_or_path, **kwargs)
    132             else:
--> 133                 ret = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
    134         elif tokenizer_class == "XLNetTokenizer":
    135             if use_fast:

/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils_base.py in from_pretrained(cls, *inputs, **kwargs)
   1423 
   1424         """
-> 1425         return cls._from_pretrained(*inputs, **kwargs)
   1426 
   1427     @classmethod

/usr/local/lib/python3.6/dist-packages/transformers/tokenization_utils_base.py in _from_pretrained(cls, pretrained_model_name_or_path, *init_inputs, **kwargs)
   1570         # Instantiate tokenizer.
   1571         try:
-> 1572             tokenizer = cls(*init_inputs, **init_kwargs)
   1573         except OSError:
   1574             raise OSError(

/usr/local/lib/python3.6/dist-packages/transformers/tokenization_bert.py in __init__(self, vocab_file, do_lower_case, do_basic_tokenize, never_split, unk_token, sep_token, pad_token, cls_token, mask_token, tokenize_chinese_chars, strip_accents, **kwargs)
    189         )
    190 
--> 191         if not os.path.isfile(vocab_file):
    192             raise ValueError(
    193                 "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "

/usr/lib/python3.6/genericpath.py in isfile(path)
     28     """Test whether a path is a regular file"""
     29     try:
---> 30         st = os.stat(path)
     31     except OSError:
     32         return False

TypeError: stat: path should be string, bytes, os.PathLike or integer, not NoneType
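A guess at the root cause, hedged: the traceback shows FARM dispatching on a tokenizer_class string and ending up in BertTokenizer.from_pretrained with vocab_file=None. If the class is inferred from substrings of the model name, then 'xlm-r-100langs-bert-base-nli-stsb-mean-tokens' also contains 'bert', so the wrong tokenizer wins, and an XLM-R (SentencePiece) checkpoint has no BERT vocab file to load. The matching rules below are illustrative only, not FARM's exact code:

```python
# Illustrative sketch of name-based tokenizer inference (not FARM's real logic).
def infer_tokenizer_class(model_name: str) -> str:
    name = model_name.lower()
    # Checking "bert" before "xlm-r" reproduces the pitfall: the XLM-R model
    # name also contains the substring "bert", so the wrong class is picked.
    if "bert" in name:
        return "BertTokenizer"
    if "xlm-r" in name:
        return "XLMRobertaTokenizer"
    return "AutoTokenizer"

print(infer_tokenizer_class("xlm-r-100langs-bert-base-nli-stsb-mean-tokens"))
# BertTokenizer, even though the checkpoint would need XLMRobertaTokenizer
```

If that is what happens, passing an explicit tokenizer class (or using the sentence_transformers model_format, once the embed input issue is sorted) might sidestep it.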

Any advice on how to use the xlm-r-100langs-bert-base-nli-stsb-mean-tokens model correctly?
