[RAG] RagSequenceForGeneration: Running the "retriever separately" example gives an error #7829
Closed
Description
Environment info
- transformers version: 3.3.1
- Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic
- Python version: 3.6.9
- PyTorch version (GPU?): 1.6.0+cu101 (False)
- Tensorflow version (GPU?): 2.3.0 (False)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
@patrickvonplaten @LysandreJik
Information
Model I am using (Bert, XLNet ...): RAG
The problem arises when using:
- the official example scripts: (give details below)
The task I am working on is:
- an official GLUE/SQuAD task: dummy_dataset
To reproduce
Steps to reproduce the behavior:
- Execute the code snippets provided (a partially modified version of the example script from https://huggingface.co/transformers/master/model_doc/rag.html)
Code snippets:
!pip install git+https://github.com/huggingface/transformers.git
!pip install datasets
!pip install faiss-cpu
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, RagSequenceForGeneration
import torch
import faiss
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)
input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
input_ids = input_dict["input_ids"]
# Calling the retriever separately
# 1. Encode
question_hidden_states = model.question_encoder(input_ids)[0]
# 2. Retrieve
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
print(docs_dict)
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
# 3. Forward to generator
outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_string)
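For reference, the `doc_scores` line in the snippet computes, for each question, the inner product between its encoder output and each retrieved document embedding: `torch.bmm` multiplies a `(batch, 1, hidden)` tensor by a `(batch, hidden, n_docs)` tensor, and `squeeze(1)` leaves a `(batch, n_docs)` result. The sketch below reproduces only that arithmetic in plain Python on made-up toy numbers so it runs without PyTorch; the function name `doc_scores_from_embeddings` is hypothetical and not part of transformers.

```python
def doc_scores_from_embeddings(question_states, doc_embeds):
    """Plain-Python equivalent of
    torch.bmm(q.unsqueeze(1), d.transpose(1, 2)).squeeze(1).

    question_states: batch x hidden
    doc_embeds:      batch x n_docs x hidden
    returns:         batch x n_docs (one inner product per retrieved doc)
    """
    scores = []
    for q, docs in zip(question_states, doc_embeds):
        # Dot product of this question's vector with each of its retrieved docs
        scores.append([sum(qh * dh for qh, dh in zip(q, d)) for d in docs])
    return scores


# Toy example: 1 question, hidden size 3, 2 retrieved documents.
q = [[1.0, 0.0, 2.0]]
docs = [[[0.5, 1.0, 1.0], [2.0, 3.0, 0.0]]]
print(doc_scores_from_embeddings(q, docs))  # [[2.5, 2.0]]
```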
Stacktrace:
AssertionError Traceback (most recent call last)
<ipython-input-5-9f622b1f6353> in <module>()
7 doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
8 # 3. Forward to generator
----> 9 outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
10 generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
11 print(generated_string)
/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
13 def decorate_context(*args, **kwargs):
14 with self:
---> 15 return func(*args, **kwargs)
16 return decorate_context
17
/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in generate(self, input_ids, attention_mask, context_input_ids, do_deduplication, num_return_sequences, num_beams, **kwargs)
902 # then, run model forwards to get nll scores:
903 new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
--> 904 outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
905 top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
906
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, context_input_ids, context_attention_mask, doc_scores, use_cache, output_attentions, output_hidden_states, output_retrieved, exclude_bos_score, reduce_loss, labels, **kwargs)
767 output_attentions=output_attentions,
768 output_hidden_states=output_hidden_states,
--> 769 output_retrieved=output_retrieved,
770 )
771
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
720 result = self._slow_forward(*input, **kwargs)
721 else:
--> 722 result = self.forward(*input, **kwargs)
723 for hook in itertools.chain(
724 _global_forward_hooks.values(),
/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, doc_scores, context_input_ids, context_attention_mask, use_cache, output_attentions, output_hidden_states, output_retrieved)
589 assert (
590 context_input_ids is not None
--> 591 ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
592 assert (
593 context_attention_mask is not None
AssertionError: Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.
I suspect `context_input_ids` is not passed on to the `forward` method. If the model is not initialised with a retriever, `forward` then complains that either `context_input_ids` or a retriever is missing. I am referring to the following piece of code in the `generate` method of the `RagSequenceForGeneration` class:
# then, run model forwards to get nll scores:
new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
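One possible way to close this gap, sketched here under assumptions rather than taken from the library: when no retriever is attached, the rescoring forward pass could receive the slice of the externally supplied context that belongs to the current example, tiled once per candidate sequence. This assumes `context_input_ids` and `context_attention_mask` hold `batch * n_docs` rows grouped by example, and that `doc_scores` has shape `(batch, n_docs)`. The helper `slice_context_for_example` is hypothetical and uses plain lists so it runs without PyTorch; with tensors, the tiling would correspond to `.repeat(num_candidates, 1)`.

```python
def slice_context_for_example(context_input_ids, context_attention_mask,
                              doc_scores, index, n_docs, num_candidates):
    """Hypothetical helper: pick the context rows for example `index`
    and tile them so each of the `num_candidates` generated sequences
    is rescored against the same retrieved documents.

    context_input_ids / context_attention_mask: (batch * n_docs) rows
    doc_scores: batch rows of n_docs scores each
    """
    start, end = index * n_docs, (index + 1) * n_docs
    # Rows for this example only, tiled block-wise like
    # tensor.repeat(num_candidates, 1): [d0, d1, d0, d1, ...]
    ids = [list(row) for _ in range(num_candidates)
           for row in context_input_ids[start:end]]
    mask = [list(row) for _ in range(num_candidates)
            for row in context_attention_mask[start:end]]
    # One copy of this example's doc scores per candidate: (num_candidates, n_docs)
    scores = [list(doc_scores[index]) for _ in range(num_candidates)]
    return ids, mask, scores


# Toy data: batch of 2 examples, n_docs = 2; token ids are placeholders.
ctx_ids = [[10, 11], [12, 13], [20, 21], [22, 23]]
ctx_mask = [[1, 1], [1, 0], [1, 1], [1, 1]]
scores = [[0.9, 0.1], [0.4, 0.6]]
ids, mask, sc = slice_context_for_example(
    ctx_ids, ctx_mask, scores, index=1, n_docs=2, num_candidates=3)
print(ids)  # [[20, 21], [22, 23], [20, 21], [22, 23], [20, 21], [22, 23]]
print(sc)   # [[0.4, 0.6], [0.4, 0.6], [0.4, 0.6]]
```

In this sketch, the tiled values would then be passed to `self(...)` as `context_input_ids`, `context_attention_mask` and `doc_scores` together with `labels=output_sequences`, so that `forward` no longer needs a retriever during rescoring.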
Expected behavior
It should work as intended, just as it does for RagTokenForGeneration.