[RAG] RagSequenceForGeneration: Running "retriever separately example" giving error #7829

@lalitpagaria

Description

Environment info

  • transformers version: 3.3.1
  • Platform: Linux-4.19.112+-x86_64-with-Ubuntu-18.04-bionic
  • Python version: 3.6.9
  • PyTorch version (GPU?): 1.6.0+cu101 (False)
  • Tensorflow version (GPU?): 2.3.0 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

@patrickvonplaten @LysandreJik

Information

Model I am using (Bert, XLNet ...): RAG

The problem arises when using:

  • the official example scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQUaD task: dummy_dataset

To reproduce

Steps to reproduce the behavior:

  1. Execute the code snippets provided below (a partially modified example script from https://huggingface.co/transformers/master/model_doc/rag.html)

Code snippets:

!pip install git+https://github.com/huggingface/transformers.git
!pip install datasets
!pip install faiss-cpu

from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration, RagSequenceForGeneration
import torch
import faiss

tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
retriever = RagRetriever.from_pretrained("facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True)

input_dict = tokenizer.prepare_seq2seq_batch("How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="pt")
input_ids = input_dict["input_ids"]

# Calling the retriever separately

question_hidden_states = model.question_encoder(input_ids)[0]
# 2. Retrieve
docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
print(docs_dict)
doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
# 3. Forward to generator
outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)

generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
print(generated_string)

Stacktrace:

AssertionError                            Traceback (most recent call last)
<ipython-input-5-9f622b1f6353> in <module>()
      7 doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)).squeeze(1)
      8 # 3. Forward to generator
----> 9 outputs = model.generate(input_ids=input_ids, context_input_ids=docs_dict["context_input_ids"], context_attention_mask=docs_dict["context_attention_mask"], doc_scores=doc_scores)
     10 generated_string = tokenizer.batch_decode(outputs, skip_special_tokens=True)
     11 print(generated_string)

5 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     13         def decorate_context(*args, **kwargs):
     14             with self:
---> 15                 return func(*args, **kwargs)
     16         return decorate_context
     17 

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in generate(self, input_ids, attention_mask, context_input_ids, do_deduplication, num_return_sequences, num_beams, **kwargs)
    902             # then, run model forwards to get nll scores:
    903             new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
--> 904             outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
    905             top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
    906 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, context_input_ids, context_attention_mask, doc_scores, use_cache, output_attentions, output_hidden_states, output_retrieved, exclude_bos_score, reduce_loss, labels, **kwargs)
    767             output_attentions=output_attentions,
    768             output_hidden_states=output_hidden_states,
--> 769             output_retrieved=output_retrieved,
    770         )
    771 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/transformers/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, doc_scores, context_input_ids, context_attention_mask, use_cache, output_attentions, output_hidden_states, output_retrieved)
    589                 assert (
    590                     context_input_ids is not None
--> 591                 ), "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function."
    592                 assert (
    593                     context_attention_mask is not None

AssertionError: Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.

I suspect that context_input_ids is not passed to the forward method inside generate. If the model is not initialised with a retriever, forward then complains that either context_input_ids or a retriever is missing. I am referring to the following piece of code in the generate method of the RagSequenceForGeneration class:

            # then, run model forwards to get nll scores:
            new_input_ids = input_ids[index : index + 1].repeat(len(output_sequences), 1)
            outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
            top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]
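The suspected mechanism can be sketched with a toy class. This is not the transformers implementation: `ToyRagSequence`, `generate_dropping_context`, and `generate_forwarding_context` are hypothetical names made up for illustration, and the candidate sequences are placeholders. The sketch only shows that when a generate-style method calls `self(...)` without forwarding `context_input_ids`, a model without a retriever hits the same kind of assertion as in the stacktrace, while forwarding the argument avoids it.

```python
class ToyRagSequence:
    """Hypothetical stand-in for the control flow; NOT the real RagSequenceForGeneration."""

    def __init__(self, retriever=None):
        self.retriever = retriever

    def forward(self, input_ids, context_input_ids=None, labels=None):
        # Mirrors the check in the traceback: without a retriever,
        # the caller has to supply the retrieved context explicitly.
        if self.retriever is None and context_input_ids is None:
            raise AssertionError(
                "Make sure that `context_input_ids` are passed, if no `retriever` is set."
            )
        # Placeholder "loss": one dummy score per candidate sequence.
        return {"loss": [0.0 for _ in labels]}

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def generate_dropping_context(self, input_ids, context_input_ids=None):
        output_sequences = [[1, 2], [3, 4]]  # placeholder candidates
        # context_input_ids is accepted but never forwarded (the suspected bug).
        return self(input_ids, labels=output_sequences)

    def generate_forwarding_context(self, input_ids, context_input_ids=None):
        output_sequences = [[1, 2], [3, 4]]  # placeholder candidates
        # Forwarding the retrieved context satisfies the check in forward.
        return self(input_ids, context_input_ids=context_input_ids, labels=output_sequences)


model = ToyRagSequence(retriever=None)

try:
    model.generate_dropping_context([0], context_input_ids=[[5, 6]])
    dropping_failed = False
except AssertionError:
    dropping_failed = True

forwarded = model.generate_forwarding_context([0], context_input_ids=[[5, 6]])
```

If this reading is right, the real fix would be to forward the retrieved context (and likely `context_attention_mask` and `doc_scores` as well) in the `self(...)` call inside `generate`; that part is an assumption based on the stacktrace, not something verified against the library source.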

Expected behavior

It should work as intended, as RagTokenForGeneration does.
