
     `i_\                    l   d Z ddlZddlmZ ddlmZmZmZ ddlZddlm	Z	 ddl
mZmZ ddlmZ dd	lmZmZmZmZ dd
lmZ ddlmZ ddlmZmZ ddlmZ ddlmZ  ej        e           Z!e ed           G d de                                  Z"ee G d de                                  Z# ed          e G d de                                  Z$e G d de$                      Z% ed           G d de$                      Z& ed           G d d e$e                      Z'g d!Z(dS )"zRAG model implementation.    N)	dataclass)CallableOptionalUnion)nn   )CacheEncoderDecoderCache)PretrainedConfig)GenerationConfigGenerationMixinLogitsProcessorListStoppingCriteriaList)ModelOutput)PreTrainedModel)auto_docstringlogging   )	RagConfig)RagRetrieverzI
    Base class for retriever augmented marginalized models outputs.
    )custom_introc                      e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eej                 ed<   dZee         ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed	<   dZeej                 ed
<   dZeej                 ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeej                 ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dS )RetrievAugLMMarginOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
        model.
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    Nlosslogits
doc_scorespast_key_valuesretrieved_doc_embedsretrieved_doc_idscontext_input_idscontext_attention_mask"question_encoder_last_hidden_state.question_enc_hidden_statesquestion_enc_attentionsgenerator_enc_last_hidden_stategenerator_enc_hidden_statesgenerator_enc_attentionsgenerator_dec_hidden_statesgenerator_dec_attentionsgenerator_cross_attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   r	   r   r   
LongTensorr    r!   r"   r#   tupler$   r%   r&   r'   r(   r)   r*        x/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/rag/modeling_rag.pyr   r   %   s        D DL )-D(5$
%,,,*.FHU&'....2J*+222'+OXe_+++8<(5#45<<<48x 0188848x 018889=HU%56===FJ&1B(CJJJJNu/@#/E)F GNNNGKXeE,=s,B&CDKKKCG#Xe.?%@GGGKO%0A30F*G!HOOOHLhuU->-C'DELLLKO%0A30F*G!HOOOHLhuU->-C'DELLLJNu/@#/E)F GNNNNNr5   r   c                      e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
ee         ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed	<   dZeej                 ed
<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeej                 ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dZeeej        df                  ed<   dS )RetrievAugLMOutputa"  
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
        each vocabulary token.
    doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
        Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
        `question_encoder_last_hidden_state`.
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
        (see `past_key_values` input) to speed up sequential decoding.
    retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
        Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
        the `doc_scores`.
    retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
        The indexes of the embedded documents retrieved by the retriever.
    context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
    context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
        Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
        retriever.
    question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
        model.
    question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
    question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
    generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
    generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
        shape `(batch_size, sequence_length, hidden_size)`.

        Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
    generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
        average in the self-attention heads.
    generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
        sequence_length)`.

        Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
        weighted average in the cross-attention heads.
    Nr   r   r   r   r   r    r!   r"   .r#   r$   r%   r&   r'   r(   r)   r*   )r+   r,   r-   r.   r   r   r/   r0   r1   r   r   r	   r   r   r2   r    r!   r"   r#   r3   r$   r%   r&   r'   r(   r)   r*   r4   r5   r6   r8   r8      s        B BH +/FHU&'....2J*+222'+OXe_+++8<(5#45<<<48x 0188848x 018889=HU%56===FJ&1B(CJJJJNu/@#/E)F GNNNGKXeE,=s,B&CDKKKCG#Xe.?%@GGGKO%0A30F*G!HOOOHLhuU->-C'DELLLKO%0A30F*G!HOOOHLhuU->-C'DELLLJNu/@#/E)F GNNNNNr5   r8   a  
    RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
    Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.

    RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a
    generator, the encoder and generator are trainable while the retriever is just an indexed dataset.
    c            
       n    e Zd ZU eed<   dZdZdZe	 	 	 d
de	e
         de	e
         dedefd	            ZdS )RagPreTrainedModelconfigragTN.question_encoder_pretrained_model_name_or_path'generator_pretrained_model_name_or_path	retrieverreturnc                    d |                                 D             }d |                                 D             }|D ]}|d|z   = 	|D ]}|d|z   = 	|                    dd          }|D|
J d            dd	lm}	 d
|vr ddlm}
  |
j        |fi |ddi\  }}||d
<    |	j        |fi |}|                    dd          }|D|
J d            ddlm} d
|vr ddlm}
  |
j        |fi |ddi\  }}||d
<    |j        |fi |}|                    d
          }|t          j
        |j        |j        fi |} | ||||          S )a  
        Instantiates an question encoder and a generator from one or two base classes of the library from pretrained
        model checkpoints.

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
        the model, you need to first set it back in training mode with `model.train()`.

        Params:
            question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the question encoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
                Information necessary to initiate the generator. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                    - A path to a *directory* containing model weights saved using
                      [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
                    - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
                      this case, `from_tf` should be set to `True` and a configuration object should be provided as
                      `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
                      PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.

            model_args (remaining positional arguments, *optional*):
                All remaining positional arguments will be passed to the underlying model's `__init__` method.
            retriever ([`RagRetriever`], *optional*):
                The retriever to use.
            kwwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the question_encoder configuration, use the prefix *question_encoder_* for each
                  configuration parameter.
                - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import RagModel

        >>> # initialize a RAG from two pretrained models.
        >>> model = RagModel.from_pretrained_question_encoder_generator(
        ...     "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./rag")
        >>> # load fine-tuned model
        >>> model = RagModel.from_pretrained("./rag")
        ```c                 n    i | ]2\  }}|                     d           |t          d           d         |3S )question_encoder_N
startswithlen.0argumentvalues      r6   
<dictcomp>zQRagPreTrainedModel.from_pretrained_question_encoder_generator.<locals>.<dictcomp>2  sW     #
 #
 #
%""#677#
S,--//0%#
 #
 #
r5   c                 n    i | ]2\  }}|                     d           |t          d           d         |3S )
generator_NrD   rG   s      r6   rK   zQRagPreTrainedModel.from_pretrained_question_encoder_generator.<locals>.<dictcomp>8  sU     
 
 
%""<00
S&&(()5
 
 
r5   rC   rM   modelNznIf `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to be defined   	AutoModelr;   )
AutoConfigreturn_unused_kwargsTzqIf `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has to be definedAutoModelForSeq2SeqLM)question_encoder	generatorr;   r?   )itemspopauto.modeling_autorQ   auto.configuration_autorR   from_pretrainedrU   getr   'from_question_encoder_generator_configsr;   )clsr=   r>   r?   kwargskwargs_question_encoderkwargs_generatorkeyrV   rQ   rR   question_encoder_configrW   rU   generator_configr;   s                   r6   *from_pretrained_question_encoder_generatorz=RagPreTrainedModel.from_pretrained_question_encoder_generator   s   H#
 #
#)<<>>#
 #
 #

 
#)<<>>
 
 
 + 	2 	2C*S011# 	+ 	+C|c)**
 366wEE#AMM NMM 766666666@@@@@@C]:C]BD D-D D *.D D D@')@
 5L'18y8>   BY    %(($77	:FF! GFF CBBBBB///@@@@@@5OZ5O;6 6?O6 6fj6 6 62 "2 .> *=-=7 ;K I
 H%%>F ')9 =C F s$4	RXdmnnnnr5   )NNN)r+   r,   r-   r   r1   base_model_prefix_supports_flash_attn_supports_sdpaclassmethodr   strr   r   rf   r4   r5   r6   r:   r:      s          N IMAE"&	Jo Jo8@Jo 2:#Jo  	Jo 
Jo Jo Jo [Jo Jo Jor5   r:   c            "           e Zd Z	 	 	 	 ddee         dee         dee         dee         f fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddee	j
                 dee	j                 d	eeee	j                                   d
ee	j
                 dee	j                 dee         dee	j                 dee	j
                 dee	j
                 dee         dee         dee         dee         dee         deee	j                 ef         fd            Z xZS )RagModelNr;   rV   rW   r?   c                 R   |||
J d            |t          j        |j        |j        fi |}n*t          || j                  sJ d| d| j                     t                                          |           | ddlm} |	                    |j
                  }| ddlm} |	                    |j                  }|| _        | j        <t          |t                    s J dt          | j                   d	            || _        || _
        || _        d| _        d
| _        dS )  
        question_encoder (`PreTrainedModel`, *optional*):
            The model responsible for encoding the question into hidden states for retrieval.
        generator (`PreTrainedModel`, *optional*):
            The model responsible for generating text based on retrieved documents.
        retriever (`RagRetriever`, *optional*):
            The component responsible for retrieving documents from a knowledge base given the encoded question.
        NzQEither a configuration or an question_encoder and a generator has to be provided.zconfig: z has to be of type rO   rP   rT   z`self.retriever` is of type z&, but should be of type `RagRetriever`F)r   r^   r;   
isinstanceconfig_classsuper__init__rZ   rQ   from_configrV   rU   rW   r?   r   typectx_encodercontext_encoder_training)	selfr;   rV   rW   r?   r`   rQ   rU   	__class__s	           r6   rs   zRagModel.__init__~  s     !&6&ByG\G\_ H]G\] >F ')9 =C FF fd&788ss:sV:s:s`d`q:s:sss8   #666666(44V5LMMBBBBBB-99&:JKKI">%i66  ktDN/C/Ckkk 6 'DN 0"(-%%%r5   	input_idsattention_maskencoder_outputsdecoder_input_idsdecoder_attention_maskr   r   r    r!   	use_cacheoutput_attentionsoutput_hidden_statesoutput_retrievedn_docsr@   c                 
   ||n| j         j        }|
|
n| j         j        }
||n| j         j        }||n| j         j        }||n| j         j        }| j        duo|du p|	du p|du o|du }||r\|                     ||d          }|d         }|                     ||                                	                    dt          j                                                  | j        j         j        |d          }| j        r|d	         |d
         |d         |d         |d         |d         f\  }}	}}}}|	                    |          }|		                    |          }	|	                    |          }|	                    |          }|                     ||d          j        }|                    d||j        d                   }t          j        |                    d          |                    dd                                        d          }n|d	         |d
         |d         |d         f\  }}	}}|	                    |          }|	                    |          }|		                    |          }	t          j        |                    d          |                    dd                                        d          }n$|
J d            |	
J d            |
J d            |
J d            |j        d         |z  dk    sJ d| d|j        d          d            ||                    |d          }||                    |d          }|                     ||	|||||
|d	  	        }|sd}d}d}d}d}n|j        }|j        }|r|sd}d}	d}d}t7          d)i d|j        d|d|j        d	|d
|	d|d|d |d!|d"|d#|j        d$|j        d%|j         d&|j!        d'|j"        d(|j#        S )*ay  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagModel
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> outputs = model(input_ids=inputs["input_ids"])
        ```NT)r{   return_dictr   cpudevicedtypeptprefixr   return_tensorsr    r!   r   tokenized_doc_idstokenized_doc_attention_maskdoc_idsr   rO   zMake sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.z^Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function.M The first dimension of `context_input_ids` should be a multiple of `n_docs`=	, but is .dim)	rz   r{   r|   r}   r~   r   r   r   r   Nr   r   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )$r;   r   r   r   r   r   r?   rV   detachtor/   float32numpyrW   r   rw   rv   pooler_outputviewshapebmm	unsqueeze	transposesqueezerepeat_interleavehidden_states
attentionsr8   r   r   encoder_last_hidden_stateencoder_hidden_statesencoder_attentionsdecoder_hidden_statesdecoder_attentionscross_attentions)rx   rz   r{   r|   r}   r~   r   r   r    r!   r   r   r   r   r   has_to_retrievequestion_enc_outputsr"   retriever_outputsr   retrieved_doc_input_idsretrieved_doc_attention_maskr   gen_outputsr#   r$   s                             r6   forwardzRagModel.forward  s   R "-4;3E!*!6IIDK<Q	1B1N--TXT_Tq$8$D  $+Jj 	 0@/K++QUQ\Qm N$& ("d*b.D.LbPZ^bPb(4' 	 " L'+'<'<n$ (= ( ($ 6J!5L2$(NN6==??BB%W\WdBeekkmm>07!#' %3 % %! 0 2! **=>)*BC)*@A)*=>)*HI))4).,/4) ):(<(<Y(G(G%-C-F-Fy-Q-Q*.E.H.H.S.S+3O3R3RS\3]3]0+/+;+;/@\jn ,< , ,# ) ,@+D+DF$F$LQ$O, ,(
 "':DDQGGI]IgIghiklImIm" "gajj J **=>)*BC)*@A))4	jf%'=?SUf ,@+B+BCe+f+f((9(<(<Y(G(G%-C-F-Fy-Q-Q* "':DDQGGI]IgIghiklImIm" "gajj J )44P 544 .99T :99 "--J .--
 %%l &%%  #f,222.\b . .!'*. . . 322 ( 1 C CFPQ C R R!-%;%M%MfZ[%M%\%\"nn'1+/#9+/ % 

 

  	F15.)-&&*##'  $)=)K&&:&E# 	%&6 	% '%)"#'  $! 
 
 
%%
!z
 (77
 0/	

 $:#9
 "6!5
 0/
 0R/Q
 (B'A
 %<$;
 -8,Q,Q
 )4(I(I
 &1%C%C
 )4(I(I
 &1%C%C
  (3'C'C!
 	
r5   NNNN)NNNNNNNNNNNNNN)r+   r,   r-   r   r   r   r   rs   r   r/   r2   Tensorr3   r0   
BoolTensorr	   boolintr   r8   r   __classcell__ry   s   @r6   rm   rm   |  s        .26:/3,00. 0.)*0. #?30. O,	0.
 L)0. 0. 0. 0. 0. 0.d  1515EI8<=A+/268<=A$(,0/3+/ $d
 d
E,-d
 !.d
 "%e.?(@"AB	d

 $E$45d
 !))9 :d
 "%d
 U./d
 $E$45d
 !))9 :d
 D>d
 $D>d
 'tnd
 #4.d
 d
  
uU\"$66	7!d
 d
 d
 ^d
 d
 d
 d
 d
r5   rm   zu
    A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
    c            &           e Zd Z	 	 	 	 d(dee         dee         dee         dee         f fdZdefdZdefd	Z	e
	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d)d
eej                 deej                 deeeej                                   deej                 deej                 dee         deej                 deej                 deej                 dee         dee         dee         dee         dee         dee         deej                 dee         def$d            Zed             Zed             Zed             Z ej                    	 	 	 	 	 	 	 	 	 d*d
eej                 deej                 deej                 deej                 deej                 d ee         d!ee         d"ee         dee         dej        fd#            Z	 d+d&Zed'             Z xZS ),RagSequenceForGenerationNr;   rV   rW   r?   c                     |||
J d            |t          j        |j        |j        fi |}t                                          |           t          ||||          | _        dS ro   NzHEither a configuration or an encoder and a generator has to be provided.)r;   rV   rW   r?   r   r^   r;   rr   rs   rm   r<   rx   r;   rV   rW   r?   r`   ry   s         r6   rs   z!RagSequenceForGeneration.__init__  s      !&6&ByG\G\V H]G\] >F ')9 =C F 	    6<LXamvwwwr5   c                     || j         _        d S r   r<   r?   rx   r?   s     r6   set_retrieverz&RagSequenceForGeneration.set_retriever      &r5   rv   c                 6    d| j         _        || j         _        d S NTr<   rw   rv   rx   rv   s     r6    set_context_encoder_for_trainingz9RagSequenceForGeneration.set_context_encoder_for_training      ,0)*r5   rz   r{   r|   r}   r~   r   r    r!   r   r   r   r   r   exclude_bos_scorereduce_losslabelsr   r@   c                 :   ||n| j         j        }||n| j         j        }||n| j         j        }|||}d}
|                     ||||||||	||
||||          }d}|0|                     |j        |j        ||| j         j        ||          }t          di d|d|j        d|j        d|j
        d	|j        d
|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        S )a3  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        exclude_bos_score (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
            the loss.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )
        ```NFrz   r{   r|   r}   r~   r    r!   r   r   r   r   r   r   r   )r   epsilonr   r   r   r   r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )r;   r   r   r   r<   get_nllr   r   label_smoothingr   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   )rx   rz   r{   r|   r}   r~   r   r    r!   r   r   r   r   r   r   r   r   r   r`   outputsr   s                        r6   r   z RagSequenceForGeneration.forward  s   N "-4;3E1B1N--TXT_Tq%0%<kk$+BY ($*!I(()+/#9/#9!+/!5-  
 
" <<"!'3"3    D ( 
 
 

>>
 ))
 $33	

 &77
 $+#A#A
 ")!=!=
 &77
 07/Y/Y
 (/'I'I
 %,$C$C
 -4,S,S
 )0(K(K
 &-%E%E
 )0(K(K
  &-%E%E!
" (/'I'I#
 	
r5   c                     | j         j        S r   r   rx   s    r6   r?   z"RagSequenceForGeneration.retrieverc      x!!r5   c                     | j         j        S r   r<   rW   r   s    r6   rW   z"RagSequenceForGeneration.generatorg  r   r5   c                     | j         j        S r   r<   rV   r   s    r6   rV   z)RagSequenceForGeneration.question_encoderk      x((r5   do_deduplicationnum_return_sequences	num_beamsc
                    |	|	n| j         j        }	||n| j         j        }||n| j         j        }||n| j         j        }||
J d            | j        ||                     ||          d         }|                     ||                                                    dt          j
                                                  | j        j         j        |	d          d	         }|                    |          }g }||
d
<   ||
d<   d|
d<   ||j        d         n|j        d         |	z  }t          |          D ]r}|||	z  |dz   |	z           } | j        j        |fi |
}|r=t          j        t%          d |D                                                                 }|j        d         }|0|||dz                                |d          } | ||d          }n|
J d            |
J d            |                    |d          }|||	z  |dz   |	z           }|                    |d          }|||dz   ddf         }|                    |d          } | ||||d          }|d                              |          d         }|                    ||                    t|                     || j         j        j                  S )a  
        Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`]` documentation
        for more information on how to set other generate input parameters.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                retriever.
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
                `context_attention_mask` have to be provided to the forward pass. They are returned by
                [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
                provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
            do_deduplication (`bool`, *optional*):
                Whether or not to deduplicate the generations from different context documents for a given input. Has
                to be set to `False` if used while training with distributed backend.
            num_return_sequences(`int`, *optional*, defaults to 1):
                The number of independently computed returned sequences for each element in the batch. Note that this
                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`]` function,
                where we set `num_return_sequences` to `num_beams`.
            num_beams (`int`, *optional*, defaults to 1):
                Number of beams for beam search. 1 means no beam search.
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            kwargs (`dict[str, Any]`, *optional*):
                Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        Nz= At least one of input_ids or context_input_ids must be givenr{   r   r   r   r   r   r    r   r   r{   r   c                 R    i | ]$}t          |                                          |%S r4   )rk   tolist)rH   ks     r6   rK   z5RagSequenceForGeneration.generate.<locals>.<dictcomp>  s(    4b4b4bAS__a4b4b4br5   T)r   r   zMake sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.zMake sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a retriever using the `set_retriever(...)` function.)r    r!   r   r   r   r   )pad_token_id)r;   r   r   r   r   r?   rV   r   r   r/   r   r   rW   r   r   rangegeneratestacklistvaluesrepeattopkappend_cat_and_padr   )rx   rz   r{   r    r!   r   r   r   r   r   model_kwargsnum_doc_return_sequencesquestion_hidden_stateshypos
batch_sizeindexgenerator_input_idsoutput_sequencesnum_candidatesnew_input_idsr   individual_input_idsindividual_attention_maskindividual_doc_scorestop_cand_indss                            r6   r   z!RagSequenceForGeneration.generateo  s   B "-4;3E/?/K++QUQ\Qm$8$D  $+Jj 	! "+!6IIDK<Q	$(9(E(EK )F(EE >%*;*C%)%:%:9Uc%:%d%def%g" $&--//22%u}2UU[[]]~,3# !/ ! ! "!# !2 4 4Y ? ?$-[!/8+,)-%&+4+@Y_Q''FWF]^_F`djFj
:&& 3	: 3	:E"3EFNeaiSYEY4Y"Z6t~6#         n#(;t4b4bQa4b4b4b4i4i4k4k/l/l#m#m -3N
 $ )%%!)*; < C CNTU V V$}5EY]^^^-99T :99 "--J .--
 (;'A'A"A( ($ -C56>UZ]^U^bhThCh,i),E,L,L^]^,_,_)(25EAI3F3I(J%(=(D(D^UV(W(W%$&:+D4+&*   &fo-334LMMaPM LL)-89999  T[5J5W XXXr5   F        c                     t          j        d d dd f                             j        d         d                               j        j        j                  gd          ||n j        j        } j        j	        p j        j        j	        }|d uo0d d df         
                    |                                          }	 fd}
t          j                            |d                              |j        d         |z  |d|                    d                    }t          j                            |d                              d                              d          }|d d d d d dd d f         }|d d d d ddd d f         }|d d d d dd d d f         }t          j        |||z   |gd          }                    d                              d                              d|dd                                          |                                k    sJ |                    d          }|                    dd	          } |
||          \  }}|r&|	r$|d d d d dd f                             d          n|                    d          }|                    d          }|                    d          }|                    d          }| }| }|r(|                                }|                                }||                    d          z  }d
|z
  |z  ||z  z   }|S )Nr   r   c                                          j        j        j                  }|                                r,|                     |d           |                    |d           |                     d          |                    d          fS Nr   r   eqr;   rW   r   anymasked_fill_r   ll
smooth_objpad_maskrx   targets      r6   
_mask_padsz4RagSequenceForGeneration.get_nll.<locals>._mask_pads  y    yy!6!CDDH||~~ 7#...''#666::b>>:#5#5b#9#999r5   r   r   rO   r   r   Tr   keepdim      ?)r/   catnewr   fill_r;   rW   r   r   bos_token_idr  allr   
functionallog_softmaxr   sizer   r   r   gathersum	logsumexp)rx   
seq_logitsr   r  r   r   r   r   r  use_bosr	  seq_logprobsdoc_logprobsfirst_token_scoressecond_token_scores	remainderrag_logprobsr  r  nll_losssmooth_losseps_ir   s   `  `                   r6   r   z RagSequenceForGeneration.get_nll  so    AAAqrrE]FJJv|A::@@AVAcddegh
 
 "-4;3E {/U4;3H3Ud*Rvaaad||/L/L/P/P/R/R	: 	: 	: 	: 	: 	: }000DDIIQ6)62zr7J7J
 
 }000CCMMbQQ[[\^__ *!!!QQQAAA+6*111aaa1aaa<8 AAAqrr111-	y"46IL6XZc!djklll !!!$$..r2299!VQJJzz|||//111111  Rv 66!%%"d%;;
#B
33J %6P'PR111abb\a   rvvayy^^A&&
\\!__))!,,
3!k 	,||~~H%//++K,++B///g)EK,??r5   c                 6   | d                              t          d | D                       t          d | D                                                     |          }d}| D ]6}|||||j        d         z   d |j        d         f<   ||j        d         z  }7|S )Nr   c              3   0   K   | ]}|j         d          V  dS )r   Nr   rH   ts     r6   	<genexpr>z8RagSequenceForGeneration._cat_and_pad.<locals>.<genexpr>C  s(      #@#@1AGAJ#@#@#@#@#@#@r5   c              3   0   K   | ]}|j         d          V  dS )r   Nr'  r(  s     r6   r*  z8RagSequenceForGeneration._cat_and_pad.<locals>.<genexpr>C  s)      EbEbUVagajEbEbEbEbEbEbr5   r   )r  r  maxr  r   )tensorsr   outputindr)  s        r6   r   z%RagSequenceForGeneration._cat_and_padA  s    #@#@#@#@#@ @ @#EbEbZaEbEbEbBbBbcciijvww 	 	A;<F3qwqz))<QWQZ<78171:CCr5   r   NNNNNNNNNNNNNNNNN)	NNNNNNNNN)Fr   FN) r+   r,   r-   r   r   r   r   rs   r   r   r   r/   r2   r   r3   r   r	   r0   r   r   r   r   propertyr?   rW   rV   no_gradr   r   staticmethodr   r   r   s   @r6   r   r     s        .26:/3,0x x)*x #?3x O,	x
 L)x x x x x x:'| ' ' ' '+O + + + +  1515@D8<=A+/8<=A26$(,0/3+/,0&*-1 $%^
 ^
E,-^
 !.^
 "%el(;"<=	^

 $E$45^
 !))9 :^
 "%^
 $E$45^
 !))9 :^
 U./^
 D>^
 $D>^
 'tn^
 #4.^
 $D>^
  d^!^
" )*#^
$ %^
( 
")^
 ^
 ^
 ^^
@ " " X" " " X" ) ) X) U]__ 15598<=A26+/.2#' $TY TYE,-TY !!12TY $E$45	TY
 !))9 :TY U./TY #4.TY 'smTY C=TY TY 
	TY TY TY _TYn os9 9 9 9v   \    r5   r   zo
    A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
    c            &       R    e Zd Z	 	 	 	 d0dee         dee         dee         dee         f fdZdefdZdefd	Z		 	 	 	 	 	 d1d
Z
ed             Zed             Zed             Zed             Zd2dZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d3deej                 deej                 deeeej                                   deej                 deej                 dee         deej                 deej                 deej                 dee         dee         dee         dee         dee         dee         deej                 d ee         d!ef$d"            Z ej                    dddddddd e             e             f
deej                 deej                 deej                 deej                 deej                 d ee         d#ee!         d$ee"eej        ge#e         f                  d%ee         d&ee          d!ej        fd'            Z$d( Z%d) Z&d* Z'd+ Z(d2d,Z)d4d/Z* xZ+S )5RagTokenForGenerationNr;   rV   rW   r?   c                     |||
J d            |t          j        |j        |j        fi |}t                                          |           t          ||||          | _        dS r   r   r   s         r6   rs   zRagTokenForGeneration.__init__Q  s      !&6&ByG\G\V H]G\] >F ')9 =C F 	    6<LXamvwwwr5   c                     || j         _        d S r   r   r   s     r6   r   z#RagTokenForGeneration.set_retrievero  r   r5   rv   c                 6    d| j         _        || j         _        d S r   r   r   s     r6   r   z6RagTokenForGeneration.set_context_encoder_for_trainingr  r   r5   c           
      :    ||d d dd f         }d ||||||d|d	S )Nr   T)	rz   r|   r   r!   r}   r   r   do_marginalizer   r4   )	rx   r}   r   r{   r   r|   r   r   r`   s	            r6   prepare_inputs_for_generationz3RagTokenForGeneration.prepare_inputs_for_generationv  sM     & 1!!!RSS& 9 .$&4!2.""

 

 
	
r5   c                     | j         j        S r   r   r   s    r6   r?   zRagTokenForGeneration.retriever  r   r5   c                     | j         j        S r   r   r   s    r6   rW   zRagTokenForGeneration.generator  r   r5   c                     | j         j        S r   r   r   s    r6   rV   z&RagTokenForGeneration.question_encoder  r   r5   c                     d d}| D ]"}|t          fd|D                       fz  }#t          | t                    rt          j        |          }|S )zeReorders cache for generation. BART-inspired but we need to take care of the extra dimension for docsc                     | j         d         |j         d         z  } | j        d|g| j         dd          R  } |                     d|          }  | j        dg| j         dd          R  }|S )Nr   r   r   rO   )r   r   index_select)r   	new_orderr   results       r6   _reorder_stackedz>RagTokenForGeneration._reorder_cache.<locals>._reorder_stacked  s    "(+yq/AAF.M.r6TM<OPQPRPR<STTTM)66q)DDM']'E]-@-DEEEFMr5   r4   c              3   `   K   | ](} |                     |j                            V  )d S r   )r   r   )rH   
past_staterD  beam_idxs     r6   r*  z7RagTokenForGeneration._reorder_cache.<locals>.<genexpr>  sA      ppWa&&z8;;z?P3Q3QRRppppppr5   )r3   rp   r
   from_legacy_cache)r   rG  reordered_past
layer_pastrD  s    `  @r6   _reorder_cachez$RagTokenForGeneration._reorder_cache  s    	 	 	 ) 	 	Jpppppeoppppp NN o':;; 	S0B>RRNr5   c                    ||n| j         j        }t          j                            |d                              |j        d         |z  |d|                    d                    }t          j        |d          }||	                    d          	                    d          z   }t          j
        |d          S )Nr   r   r   r   )r;   r   r   r  r  r   r   r  r/   r   r  )rx   r  r   r   r  r  log_prob_sums          r6   marginalizez!RagTokenForGeneration.marginalize  s    !-4;3E }000DDIIQ6)62zr7J7J
 
 (;;;#l&<&<R&@&@&J&J2&N&NN|3333r5   rz   r{   r|   r}   r~   r   r    r!   r   r   r   r   r   r:  r   r   r   r@   c                    ||n| j         j        }||n| j         j        }||n| j         j        }|||}d}
|                     ||||||||	||
||||          }d}|j        }|3|J |                     |j        |j        ||| j         j        |          }|r| 	                    ||j        |          }t          di d|d|d|j        d|j        d	|j        d
|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        d|j        S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
            which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
            obtain the indices.

            [What are input IDs?](../glossary#input-ids)
        encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
            Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
            *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
            sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
            generator's encoder.

            Used by the ([`RagModel`]) model during decoding.
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Provide for generation tasks. `None` by default, construct as per instructions for the generator model
            you're using with your RAG instance.
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,  target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
            Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
            the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
        context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
            Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
            retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
            provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
        doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
            Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
            `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
            has to be provided to the forward pass. `doc_scores` can be computed via
            `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
        output_retrieved (`bool`, *optional*):
            Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
            `context_attention_mask`. See returned tensors for more detail.
        do_marginalize (`bool`, *optional*):
            If `True`, the logits are marginalized over all documents by making use of
            `torch.nn.functional.log_softmax`.
        reduce_loss (`bool`, *optional*):
            Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
            operation.
        n_docs (`int`, *optional*):
            The number of documents to retrieve.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
        >>> retriever = RagRetriever.from_pretrained(
        ...     "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
        ... )
        >>> # initialize with RagRetriever to do everything in one forward call
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

        >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
        >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
        >>> input_ids = inputs["input_ids"]
        >>> labels = targets["input_ids"]
        >>> outputs = model(input_ids=input_ids, labels=labels)

        >>> # or use retriever separately
        >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
        >>> # 1. Encode
        >>> question_hidden_states = model.question_encoder(input_ids)[0]
        >>> # 2. Retrieve
        >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
        >>> doc_scores = torch.bmm(
        ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
        ... ).squeeze(1)
        >>> # 3. Forward to generator
        >>> outputs = model(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ...     decoder_input_ids=labels,
        ... )

        >>> # or directly generate
        >>> generated = model.generate(
        ...     context_input_ids=docs_dict["context_input_ids"],
        ...     context_attention_mask=docs_dict["context_attention_mask"],
        ...     doc_scores=doc_scores,
        ... )
        >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
        ```NFr   )r   r   r   r   r   r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   r4   )r;   r   r:  r   r<   r   r   r   r   rN  r   r   r    r!   r   r   r"   r#   r$   r%   r&   r'   r(   r)   r*   )rx   rz   r{   r|   r}   r~   r   r    r!   r   r   r   r   r   r:  r   r   r   r`   r   r   r   s                         r6   r   zRagTokenForGeneration.forward  s+   ^ "-4;3E+9+E4;Ke%0%<kk$+BY ($*!I(()+/#9/#9!+/!5-  
 
" $000<<"'3    D  	J%%fg.@&IIF' 
 
 

6
 ))
 $33	

 &77
 $+#A#A
 ")!=!=
 &77
 07/Y/Y
 (/'I'I
 %,$C$C
 -4,S,S
 )0(K(K
 &-%E%E
 )0(K(K
  &-%E%E!
" (/'I'I#
 	
r5   generation_configprefix_allowed_tokens_fnlogits_processorstopping_criteriac           	         || j         }t          j        |          } |j        d&i |}|                    dd          du}|                     ||           n| j        j        | j        2|/| 	                    ||          d         }|                     ||
                                                    dt          j                                                  | j        j        j        d          }|d	         |d
         |d         }}}|                    |          }|                    |          }|                    |          }t          j        |                    d          |                    dd                                        d          }|j        d         z  dk    sJ d d|j        d          d            |j        d         z  | j        j                                        } |||d          }t          j        |j        z  df|j        t          j        t9          |                                           j                  }|j        d         }|d         }d'fd	} |||j                  } |||j                  |d<   |                    |j        d          }||d<   ||d<   ||d<   |d<   |                      |||||	|j                  }| !                    ||
          }| "                    ||d|j        d         |j#        dz
             |j        dk    r7|j$        dk    rtK          d|j$         d            | j&        |f|||d!dd"|S |j        dk    r2|j$        |j        k    rtK          d#           | j'        |f|||d!d$|S tK          d%|j                   )(a  
        Implements RAG token decoding.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                `context_input_ids` has to be provided.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                retriever.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                `question_encoder_last_hidden_state`.

                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.
            prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
                If provided, this function constraints the beam search to allowed tokens only at each step. If not
                provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
                `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on
                the previously generated tokens `inputs_ids` and the batch ID `batch_id`. This argument is useful for
                constrained generation conditioned on the prefix, as described in [Autoregressive Entity
                Retrieval](https://huggingface.co/papers/2010.00904).
            logits_processor (`LogitsProcessorList`, *optional*):
                Custom logits processors that complement the default logits processors built from arguments and a
                model's config. If a logit processor is passed that is already created with the arguments or a model's
                config an error is thrown.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Custom stopping criteria that complement the default stopping criteria built from arguments and a
                model's config. If a stopping criteria is passed that is already created with the arguments or a
                model's config an error is thrown.
            kwargs (`dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model.

        Return:
            `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
            sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
            finished early due to the `eos_token_id`.
        Nr{   r   r   r   r   r   r   r    r!   r   r   rO   r   r   r   T)rz   r{   r   )r   r   r   last_hidden_statec                    | d d d d f                              df| j        dd          z             } |                     |f| j        dd          z             } |                      |z  z  f| j        dd          z             S )Nr   r   )reshaper   expand)tensorr   r   r   s     r6   extend_enc_outputz9RagTokenForGeneration.generate.<locals>.extend_enc_output  s    D$M*22J63JV\Z[Z\Z\M]3]^^F]]J	6#BV\RSRTRTEU#UVVF>>:	#9F#B"Dv|TUTVTVGW"WXXXr5   )r   r   r   r|   r   )rP  input_ids_seq_lengthencoder_input_idsrQ  rR  r   )rP  rS  )generation_moder   max_cache_lengthz)num_return_sequences has to be 1, but is z when doing greedy search.F)rR  rS  rP  synced_gpusstreamerzA`num_return_sequences` has to be smaller or equal to `num_beams`.)rR  rS  rP  r_  uH   `num_beams` has to be an integer strictly superior to 0 (≥ 1), but is r4   r   )(rP  copydeepcopyupdater]   _prepare_special_tokensr;   r   r?   rV   r   r   r/   r   r   rW   r   r   r   r   r   r   r<   get_encoderfullr   decoder_start_token_idlongnext
parametersr   r   _get_logits_processor_get_stopping_criteria_prepare_cache_for_generation
max_lengthr   
ValueError_sample_beam_search)rx   rz   r{   r    r!   r   r   rP  rQ  rR  rS  r`   r   kwargs_has_attention_maskr   outr   encoderr|   r[  rU  rZ  pre_processorprepared_stopping_criteriar   s         `                 @r6   r   zRagTokenForGeneration.generatem  s    b $ $ 6 M*;<</(/99&99$0$4$45Et$L$LTX$X!$$%68QRRR "-4;3E >%*;*C%)%:%:9Uc%:%d%def%g"..&--//22%u}2UU[[]]~,3# !  C '(,-*+ 8L5 $8#:#:;Q#R#R  1 4 4Y ? ?%;%>%>y%I%I" #9#C#CA#F#FH\HfHfghjkHlHlmmuu J "'*V3999.\b . .!'*. . . :99 ',Q/69
($0022!',=NdrvwwwJ+55q94*))**1	
 
 
	  )r2+,?@	Y 	Y 	Y 	Y 	Y 	Y 	Y "3!23IUfUp!q!q!q/@/@):)D0
 0
 0
+,  112C2MST1UU
 &0\"*9&')?%&!'X22/!5/%=-# 3 
 
 &*%@%@/CT &A &
 &
" 	**  q).9A= 	+ 	
 	
 	
 &!++ 599 &@Q@f & & &    4<!."<"3!     (1,, 58I8SSS !deee$4$!."<"3!     x[l[vxx  r5   c                 2    |                      ||          }|S r   )rK  )rx   r   rG  s      r6   _temporary_reorder_cachez.RagTokenForGeneration._temporary_reorder_cacheC  s     --oxHHr5   c                 >    | j         j                                        S r   )r<   rW   get_input_embeddingsr   s    r6   rz  z*RagTokenForGeneration.get_input_embeddingsJ  s    x!66888r5   c                 >    | j         j                                        S r   )r<   rW   get_output_embeddingsr   s    r6   r|  z+RagTokenForGeneration.get_output_embeddingsM  s    x!77999r5   c                 @    | j         j                            |          S r   )r<   rW   set_output_embeddings)rx   new_embeddingss     r6   r~  z+RagTokenForGeneration.set_output_embeddingsP  s    x!77GGGr5   c                     || j         j        }|                    |j                  }|ddddf                                         |ddddf<   ||dddf<   |S )zCShift input ids one token to the right, and pad with start_token_idNr   r   r   )r;   rg  	new_zerosr   clone)rx   rz   start_token_idshifted_input_idss       r6   shift_tokens_rightz(RagTokenForGeneration.shift_tokens_rightS  su    !![?N%//	@@#,QQQV#4#:#:#<#<!!!QRR% "0!!!Q$  r5   Fr   c                 (    ||n j         j        }t          j        d d dd f                             j        d         d                               j         j        j                  gd           fd} 	                    |||          }
                    d                                          |                                k    sJ |                    d          }	|                    dd          }
 ||	|
          \  }	}
|	                    d          }	|
                    d          }
|	 }|
 }|r(|                                }|                                }||                    d          z  }d|z
  |z  ||z  z   }|S )	Nr   r   c                                          j        j        j                  }|                                r,|                     |d           |                    |d           |                     d          |                    d          fS r   r   r  s      r6   r	  z1RagTokenForGeneration.get_nll.<locals>._mask_padsc  r
  r5   r   r  Tr  r  )r;   r   r/   r  r  r   r  rW   r   rN  r   r   r  r  r  )rx   r  r   r  r   r   r   r	  r!  r  r  r"  r#  r$  r   s   `  `           r6   r   zRagTokenForGeneration.get_nll\  s   !-4;3EAAAqrrE]FJJv|A::@@AVAcddegh
 
	: 	: 	: 	: 	: 	: ''
JGG!!"%%zz|||//111111  Rv 66!%%"d%;;
#B
33JVVAYY^^A&&
3!k 	,||~~H%//++K,++B///g)EK,??r5   r   )NNNNNNr   r0  )Fr   N),r+   r,   r-   r   r   r   r   rs   r   r   r;  r1  r?   rW   rV   r3  rK  rN  r   r/   r2   r0   r3   r   r   r	   r   r   r   r   r2  r   r   r   r   r   r   rx  rz  r|  r~  r  r   r   r   s   @r6   r5  r5  K  s        .26:/3,0x x)*x #?3x O,	x
 L)x x x x x x<'| ' ' ' '+O + + + + 
 
 
 
: " " X" " " X" ) ) X)   \*	4 	4 	4 	4  156:@D8<=A+/8<=A26$(,0/3+/)-&*-1 $%j
 j
E,-j
 !!23j
 "%el(;"<=	j

 $E$45j
 !))9 :j
 "%j
 $E$45j
 !))9 :j
 U./j
 D>j
 $D>j
 'tnj
 #4.j
 !j
  d^!j
" )*#j
$ %j
( 
")j
 j
 j
 ^j
X U]__ 15598<=A26 $8<W[:M:M:O:O<P<P<R<RR RE,-R !!12R $E$45	R
 !))9 :R U./R R $$45R #+8S%,4Gc4R+S"TR ##67R $$89R 
	R R R _Rj  9 9 9: : :H H H! ! ! !" " " " " " " "r5   r5  )rm   r:   r   r5  ))r.   ra  dataclassesr   typingr   r   r   r/   r   cache_utilsr	   r
   configuration_utilsr   
generationr   r   r   r   modeling_outputsr   modeling_utilsr   utilsr   r   configuration_ragr   retrieval_ragr   
get_loggerr+   loggerr   r8   r:   rm   r   r5  __all__r4   r5   r6   <module>r     sP       ! ! ! ! ! ! , , , , , , , , , ,        5 5 5 5 5 5 5 5 3 3 3 3 3 3 f f f f f f f f f f f f + + + + + + - - - - - - , , , , , , , , ( ( ( ( ( ( ' ' ' ' ' ' 
	H	%	%   
WO WO WO WO WO{ WO WO  WOt TO TO TO TO TO TO TO  TOn    Qo Qo Qo Qo Qo Qo Qo  Qoh X
 X
 X
 X
 X
! X
 X
 X
v   
k k k k k1 k k 
k\   
n n n n n. n n 
nb b
a
ar5   