
     `i                    6   d dl Z d dlZd dlmZ d dlmZmZmZ d dlZd dl	m
c mZ d dlm
Z
 d dlmZmZmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZmZmZ ddl m!Z! ddl"m#Z#m$Z$m%Z% ddl&m'Z' ddl(m)Z)m*Z*  e$            rd dl+m,Z, d dl-m.Z. d dl/m0Z0 ne1Z. e%j2        e3          Z4 G d de          Z5	 	 dUdej6        dej6        deej6                 deej6                 de7ej6        ej6        ej6        e8eej6                 eej6                 f         f
dZ9dej6        dej6        de8de8dej6        f
dZ: G d  d!ej;        j<                  Z=	 	 dUd"eej6                 d#ee8         fd$Z> G d% d&e.          Z? G d' d(e
j@                  ZA G d) d*e
j@                  ZB G d+ d,e)          ZC	 dVd.d/d0ej6        dej6        d1ej6        deejD                 d2e7e8e8f         d3e8d4e8d5eeE         dee7ej6        ej6        f         e7ej6                 f         fd6ZFejG        fd.d/d0ej6        d7e?d"ej6        d#e8d2e7e8e8f         d3e8d4e8d8ejH        de7ej6                 fd9ZId.d/d0ej6        dej6        d1ej6        deejD                 d2e7e8e8f         d3e8d4e8de7ej6                 fd:ZJeIeFeJd;ZK G d< d/e
j@                  ZL G d= d>e          ZMe# G d? d@e!                      ZNe# G dA dBeN                      ZO G dC dDe
j@                  ZP e#dEF           G dG dHeN                      ZQ e#dIF           G dJ dKeN                      ZR e#dLF           G dM dNeN                      ZSe# G dO dPeN                      ZT e#dQF           G dR dSeN                      ZUg dTZVdS )W    N)nullcontext)LiteralOptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)PretrainedConfig)_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)auto_docstringis_flash_attn_2_availablelogging)is_triton_available   )GemmaRotaryEmbeddingapply_rotary_pos_emb) flash_attn_varlen_qkvpacked_func)RotaryEmbedding)apply_rotaryc                        e Zd ZdZdZddiZdgZ	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 d!ded         f fdZ fd Z	 xZ
S )"ModernBertConfiga  
    This is the configuration class to store the configuration of a [`ModernBertModel`]. It is used to instantiate an ModernBert
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the ModernBERT-base.
    e.g. [answerdotai/ModernBERT-base](https://huggingface.co/answerdotai/ModernBERT-base)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 50368):
            Vocabulary size of the ModernBert model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`ModernBertModel`]
        hidden_size (`int`, *optional*, defaults to 768):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 1152):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 22):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer decoder.
        hidden_activation (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the decoder. Will default to `"gelu"`
            if not specified.
        max_position_embeddings (`int`, *optional*, defaults to 8192):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_cutoff_factor (`float`, *optional*, defaults to 2.0):
            The cutoff factor for the truncated_normal_initializer for initializing all weight matrices.
        norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        norm_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the normalization layers.
        pad_token_id (`int`, *optional*, defaults to 50283):
            Padding token id.
        eos_token_id (`int`, *optional*, defaults to 50282):
            End of stream token id.
        bos_token_id (`int`, *optional*, defaults to 50281):
            Beginning of stream token id.
        cls_token_id (`int`, *optional*, defaults to 50281):
            Classification token id.
        sep_token_id (`int`, *optional*, defaults to 50282):
            Separation token id.
        global_rope_theta (`float`, *optional*, defaults to 160000.0):
            The base period of the global RoPE embeddings.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        global_attn_every_n_layers (`int`, *optional*, defaults to 3):
            The number of layers between global attention layers.
        local_attention (`int`, *optional*, defaults to 128):
            The window size for local attention.
        local_rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the local RoPE embeddings.
        embedding_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the MLP layers.
        mlp_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the MLP layers.
        decoder_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the decoder layers.
        classifier_pooling (`str`, *optional*, defaults to `"cls"`):
            The pooling method for the classifier. Should be either `"cls"` or `"mean"`. In local attention layers, the
            CLS token doesn't attend to all tokens on long sequences.
        classifier_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the classifier.
        classifier_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the classifier.
        classifier_activation (`str`, *optional*, defaults to `"gelu"`):
            The activation function for the classifier.
        deterministic_flash_attn (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic flash attention. If `False`, inference will be faster but not deterministic.
        sparse_prediction (`bool`, *optional*, defaults to `False`):
            Whether to use sparse prediction for the masked language model instead of returning the full dense logits.
        sparse_pred_ignore_index (`int`, *optional*, defaults to -100):
            The index to ignore for the sparse prediction.
        reference_compile (`bool`, *optional*):
            Whether to compile the layers of the model which were compiled during pretraining. If `None`, then parts of
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

    ```python
    >>> from transformers import ModernBertModel, ModernBertConfig

    >>> # Initializing a ModernBert style configuration
    >>> configuration = ModernBertConfig()

    >>> # Initializing a model from the modernbert-base style configuration
    >>> model = ModernBertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
modernbert
rope_thetaglobal_rope_thetapast_key_values             gelu    {Gz?       @h㈵>Fk  j  i       A        r           @TclsNclassifier_poolingr8   meanc$           	      ,    t                      j        d|||||d|$ || _        || _        || _        || _        || _        || _        || _        |	| _	        |
| _
        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        || _        | | _        |!| _        |"| _        |#| _        | j        dvrtA          d| j         d          d S )N)pad_token_idbos_token_ideos_token_idcls_token_idsep_token_idr;   zQInvalid value for `classifier_pooling`, should be either "cls" or "mean", but is . )!super__init__
vocab_sizemax_position_embeddingshidden_sizeintermediate_sizenum_hidden_layersnum_attention_headsinitializer_rangeinitializer_cutoff_factornorm_eps	norm_biasr%   attention_biasattention_dropouthidden_activationglobal_attn_every_n_layerslocal_attentionlocal_rope_thetaembedding_dropoutmlp_biasmlp_dropoutdecoder_biasr:   classifier_dropoutclassifier_biasclassifier_activationdeterministic_flash_attnsparse_predictionsparse_pred_ignore_indexreference_compilerepad_logits_with_grad
ValueError)&selfrG   rI   rJ   rK   rL   rS   rH   rM   rN   rO   rP   r>   r@   r?   rA   rB   r%   rQ   rR   rT   rU   rV   rW   rX   rY   rZ   r:   r[   r\   r]   r^   r_   r`   ra   rb   kwargs	__class__s&                                        /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/modernbert/modular_modernbert.pyrF   zModernBertConfig.__init__   sl   N 	 	
%%%%%	
 	
 	
 	
 	
 %'>$&!2!2#6 !2)B& "!2,!2!2*D'. 0!2 &("4"4.%:"(@%!2(@%!2&<#"/99~dhd{~~~   :9    c                 t    t                                                      }|                    dd            |S )Nra   )rE   to_dictpop)rd   outputrf   s     rg   rj   zModernBertConfig.to_dict   s0    ""

&---rh   )#r'   r(   r)   r*   r+   r,   r-   r.   r/   r0   Fr1   r2   r3   r3   r2   r4   Fr5   r   r6   r7   r5   Fr5   Tr8   r5   Fr,   FFr9   NF)__name__
__module____qualname____doc__
model_typeattribute_mapkeys_to_ignore_at_inferencer   rF   rj   __classcell__rf   s   @rg   r"   r"   7   s       e eN J!#67M#4"5   $"%"#$ 5:$!&!%$IQ Q8 $M29Q Q Q Q Q Qf        rh   r"   inputsattention_maskposition_idslabelsreturnc                    |                     dt          j                  }t          j        |                                d                                          }t          |                                                                          }t          j        j	        
                    t          j        |dt          j                  d          }|                                 dk    r|                                 |         }n#| j        ^}	}
}|	|
z  } | j        |g|R  |         }||                                |         nd}||                                |         nd}||||||fS )	a  
    Remove padding from input sequences.

    Args:
        inputs: (batch, seqlen, ...) or (batch, seqlen)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        position_ids: (batch, seqlen), int, position ids
        labels: (batch, seqlen), int, labels

    Returns:
        unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        cu_seqlens: (batch + 1), the cumulative sequence lengths
        max_seqlen_in_batch: int
        unpadded_position_ids: (total_nnz) or None
        unpadded_labels: (total_nnz) or None
    dimdtypeF)as_tupler   )   r   r   N)sumtorchint32nonzeroflattenintmaxitemr   
functionalpadcumsumr~   shapeview)rv   rw   rx   ry   seqlens_in_batchindicesmax_seqlen_in_batch
cu_seqlensunpadded_inputsbatchseqlenrestr   unpadded_position_idsunpadded_labelss                  rg   _unpad_modernbert_inputr      sW   . &))b)DDmN2244uEEEMMOOG.224499;;<<$((6FAUZU`)a)a)acijjJzz||q ..**73%|v%&+e3d333G<?K?WL00227;;]a393Efnn&&w//4OGZ1DF[]lllrh   r   r   r   c                 6   |                                  dk    r@t          j        ||z  | j        | j                  }| ||<   |                    ||          }n@| j        ^}}t          j        ||z  g|R | j        | j        d}| ||<    |j        ||g|R  }|S )aQ  
    Add padding to sequences.

    Args:
        inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz)
        batch: int, batch size
        seqlen: int, max sequence length

    Returns:
        padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
    r   r   device)r~   r   zerosr   r   r   r   )rv   r   r   r   rl   padded_inputs_r   s           rg   _pad_modernbert_outputr   %  s    $ zz||qUV^6<VVV wE622<DUV^]d]]&,v}]]] w#E69D999rh   c                   l    e Zd Ze	 	 ddeej                 dee         fd            Zed             Z	dS )ApplyRotaryEmbUnpadNr   
max_seqlenc           
          |                                 }|j        \  }}}}	|d d d df                             |d|	          }
t          |
||d||dd           |                     |||           || _        |S )Nr   r|   r   FT)seqlen_offsetsr   r   interleavedinplace)
contiguousr   r   r    save_for_backwardr   )ctxqkvcossinr   r   	total_nnz_three_nheadsheaddimqks              rg   forwardzApplyRotaryEmbUnpad.forwardE  s     nn.1i+	67G BQBZ__YG44!!		
 		
 		
 		
 	c3
333#
rh   c                     | j         \  }}}|                                }|j        \  }}}}|d d d df                             |d|          }	t	          |	||d|| j        ddd	  	         |d d d d d d fS )Nr   r|   r   FT)r   r   r   r   r   	conjugate)saved_tensorsr   r   r   r    r   )
r   dor   r   r   r   r   r   r   dqks
             rg   backwardzApplyRotaryEmbUnpad.backwardd  s    "0S*]]__.0h+	67G BQBinnYG44!~
	
 
	
 
	
 
	
 4tT455rh   NN)
rm   rn   ro   staticmethodr   r   Tensorr   r   r   rD   rh   rg   r   r   D  sy         .2$( 
 U\* SM   \< 6 6 \6 6 6rh   r   r   r   c                 >    t                               | ||||          S )a  
    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - input tensor for packed QKV.
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
        cu_seqlens: (batch + 1,) or None
        max_seqlen: int
    Return:
        out: (total_nnz, dim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    )r   apply)r   r   r   r   r   s        rg   apply_rotary_unpaddedr   {  s     . $$S#sJ
KKKrh   c                        e Zd ZdZ	 	 	 	 ddededee         deej                 deej	                 f
 fd	Z
	 dd
ej        dej        dee         deej        eej        ej        f         f         fdZdefdZ xZS )!ModernBertUnpaddedRotaryEmbeddingzP
    The rotary position embeddings applied directly to unpadded sequences.
    r7   Nr~   baser   r   r   c                     t                                          |||d           || _        ||||                     |||           dS dS dS dS )a  
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
            up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
            the cos_sin_cache will be recomputed during the forward pass.
        F)r~   r   r   r   Nr   r   )rE   rF   r   _update_cos_sin_cache)rd   r~   r   r   r   r   rf   s         rg   rF   z*ModernBertUnpaddedRotaryEmbedding.__init__  sr     	StFNNN$!f&8U=N&&z&&NNNNN "!&8&8=N=Nrh   r   r   rz   c                     |"|                      ||j        |j                   t          || j        | j        ||          }|S )z
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        Nr   r   r   )r   r   r   r   _cos_cached_sin_cached)rd   r   r   r   s       rg   r   z)ModernBertUnpaddedRotaryEmbedding.forward  sY     !&&z#*CI&VVV#!!
 
 
 
rh   c                 6    d| j          d| j         d| j         S )Nzdim=z, base=z, scale_base=)r~   r   
scale_baserd   s    rg   
extra_reprz,ModernBertUnpaddedRotaryEmbedding.extra_repr  s&    PdhPPtyPPtPPPrh   )r7   NNNN)rm   rn   ro   rp   r   floatr   r   r   r   rF   r   r   tupler   strr   rt   ru   s   @rg   r   r     s2         $()-'+O OO O SM	O
 &O $O O O O O O. %)	 \ L SM	
 
u|U5<#=>>	?   2QC Q Q Q Q Q Q Q Qrh   r   c                        e Zd ZdZdef fdZ ej        d          dej        dej	        fd            Z
	 ddeej                 d
eej	                 dej	        fdZ xZS )ModernBertEmbeddingszV
    Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
    configc                 >   t                                                       || _        t          j        |j        |j        |j                  | _        t          j	        |j        |j
        |j                  | _        t          j        |j                  | _        d S )N)padding_idxepsbias)rE   rF   r   r   	EmbeddingrG   rI   r>   tok_embeddings	LayerNormrO   rP   normDropoutrW   droprd   r   rf   s     rg   rF   zModernBertEmbeddings.__init__  s{     l6+<f>P^d^qrrrL!3vO_```	Jv788			rh   Tdynamic	input_idsrz   c                 x    |                      |                     |                     |                              S r   )r   r   r   )rd   r   s     rg   compiled_embeddingsz(ModernBertEmbeddings.compiled_embeddings  s.    yy4#6#6y#A#ABBCCCrh   Ninputs_embedsc                    |)|                      |                     |                    }n\| j        j        r|                     |          n:|                      |                     |                     |                              }|S r   )r   r   r   ra   r   r   )rd   r   r   hidden_statess       rg   r   zModernBertEmbeddings.forward  s     $ IIdii&>&>??MM ;0J((333YYtyy)<)<Y)G)GHHII 
 rh   r   )rm   rn   ro   rp   r"   rF   r   compile
LongTensorr   r   r   r   rt   ru   s   @rg   r   r     s         9/ 9 9 9 9 9 9 U]4   DU-= D%, D D D ! D ei !%"23KSTYT`Ka	       rh   r   c                   L     e Zd ZdZdef fdZdej        dej        fdZ xZ	S )ModernBertMLPa6  Applies the GLU at the end of each ModernBERT layer.

    Compared to the default BERT architecture, this block replaces :class:`~transformers.model.bert.modeling_bert.BertIntermediate`
    and :class:`~transformers.model.bert.modeling_bert.SelfOutput` with a single module that has similar functionality.
    r   c                    t                                                       || _        t          j        |j        t          |j                  dz  |j                  | _	        t          |j                 | _        t          j        |j                  | _        t          j        |j        |j        |j                  | _        d S )Nr   r   )rE   rF   r   r   LinearrI   r   rJ   rX   Wir   rS   actr   rY   r   Wor   s     rg   rF   zModernBertMLP.__init__  s    )F.F4L0M0MPQ0QX^Xghhh&23Jv122	)F4f6Hv___rh   r   rz   c                     |                      |                              dd          \  }}|                     |                     |                     |          |z                      S )Nr   r|   r~   )r   chunkr   r   r   )rd   r   inputgates       rg   r   zModernBertMLP.forward  sW    ggm,,221"2==twwtyy%4!788999rh   )
rm   rn   ro   rp   r"   rF   r   r   r   rt   ru   s   @rg   r   r     s|         `/ ` ` ` ` ` `:U\ :el : : : : : : : :rh   r   c                       e Zd ZdS )ModernBertRotaryEmbeddingN)rm   rn   ro   rD   rh   rg   r   r     s        Drh   r   FmoduleModernBertAttentionr   sliding_window_maskrU   bsr~   output_attentionsc	                    |                      ||          \  }
}|                    dd                              d          \  }}}t          |||
|          \  }}| j        dz  }t          j        ||                    dd                    |z  }|dk    r|}||z   }t          j        	                    |dt
          j
        	                              |j                  }t          j                            || j        | j        
          }t          j        ||          }|                    dd                                          }|                    |d|          }|r||fS |fS )Nrx   r   r   r   r         ࿩r|   r|   r|   r}   )ptraining)
rotary_emb	transposeunbindr   head_dimr   matmulr   r   softmaxfloat32tor   dropoutrR   r  r   r   )r   r   rw   r   rx   rU   r   r~   r   _kwargsr   r   querykeyvaluescaleattn_weightsattn_outputs                     rg   eager_attention_forwardr     se      < @@HCa++22q299E3%eS#s;;JE3OT!E<s}}Q':':;;eCL("",.0L =((2U](SSVVW\WbccL=((9Q\b\k(llL,|U33K''1--88::K""2r3//K +\**>rh   r  target_dtypec	                     ||||          }|j         t          j        t          j        fv}
|
rZ|j         }|                    |          }t          |||| j        r| j        nd| j        |          }|                    |          }n(t          |||| j        r| j        nd| j        |          }|	                    ||          fS )Nr   r5   )r   r   	dropout_pdeterministicwindow_size)
r   r   float16bfloat16r  r   r  rR   r^   r   )r   r   r  r   r   rU   r   r~   r  r  convert_dtype
orig_dtypeattns                rg   flash_attention_forwardr   %  s     *SZJ
G
G
GCIemU^%DDM 
 Y
ff\""/!!28/Jf..s 9'
 
 
 wwz""/!!28/Jf..s 9'
 
 
 IIb#  rh   c                    |                      ||          \  }	}
|                    dd                              d          \  }}}t          |||	|
          \  }}|dk    r|}t	          j        |||| j        r| j        nd|                              dd                                          }|	                    |d	|          }|fS )
Nr   r   r   r   r   r  r5   )r  	attn_maskr|   )
r  r  r  r   Fscaled_dot_product_attentionr  rR   r   r   )r   r   rw   r   rx   rU   r   r~   r  r   r   r  r  r  r  s                  rg   sdpa_attention_forwardr%  P  s       < @@HCa++22q299E3%eS#s;;JE3("", 	
&28/Jf..s$	
 	
 	
 
1a	  ""2r3//K>rh   )flash_attention_2eagersdpac                   r     e Zd ZdZddedee         f fdZ	 ddej	        dee
         d	ej	        fd
Z xZS )r   a  Performs multi-headed self attention on a batch of unpadded sequences.

    If Flash Attention 2 is installed, this module uses Flash Attention to improve throughput.
    If Flash Attention 2 is not installed, the implementation will use PyTorch's SDPA kernel,
    which requires padding and unpadding inputs, adding some overhead.

    See `forward` method for additional details.
    Nr   layer_idc                    t                                                       || _        || _        |j        |j        z  dk    r t          d|j         d|j         d          |j        | _        |j        | _        |j        | _	        |j        |j        z  | _
        | j
        | j	        z  | _        t          j        |j        d| j        z  |j                  | _        ||j        z  dk    r6|j        dz  |j        dz  f| _        |j        |j        n|j        }|j        }nd| _        |j        }|j        }|j        d	k    rt-          | j
        ||
          | _        n0t1          j        |          }||_        t7          |          | _        t          j        |j        |j        |j                  | _        |j        dk    rt          j        |j                  nt          j                    | _        tA                      | _!        d S )Nr   zThe hidden size (z6) is not a multiple of the number of attention heads ()r   r   r   r  r&  )r~   r   r   )r   r5   )"rE   rF   r   r*  rI   rL   rc   rR   r^   	num_headsr  all_head_sizer   r   rQ   WqkvrT   rU   rV   r%   rH   _attn_implementationr   r  copydeepcopyr$   r   r   r   Identityout_dropsetpruned_heads)rd   r   r*  r$   rH   config_copyrf   s         rg   rF   zModernBertAttention.__init__  s     ::a?? LF$6  L  Lnt  oI  L  L  L   "(!9(.(G%3*f.HH!]T^;If0!d6H2HvOdeee	f771<<$*$:a$?AW[\A\#]D 4:4K4W00]c]uJ&,&<###+D &,&D#1J&*===?M.EJ  DOO -//K%/K"7{KKKDO)F.0BI^___@F@X[^@^@^
6#;<<<dfdodqdqEErh   Fr   r   rz   c           
         |                      |          }|j        d         }| j        j        dk    r#|                    dd| j        | j                  }n#|                    |dd| j        | j                  }t          | j        j                 | f|| j        | j	        || j
        |d|}|d         }|                     |                     |                    }|f|dd          z   S )Nr   r&  r|   r   )r   r  rU   r   r~   r   r   )r/  r   r   r0  r   r-  r  MODERNBERT_ATTENTION_FUNCTIONr  rU   r.  r4  r   )rd   r   r   re   r   r   attn_outputss          rg   r   zModernBertAttention.forward  s     ii&& #;+/BBB((2q$.$-@@CC((2r1dndmDDC4T[5UV	
 0"/	
 	
 	
 	
 %Qdggm&<&<==,qrr"222rh   r   F)rm   rn   ro   rp   r"   r   r   rF   r   r   boolr   rt   ru   s   @rg   r   r   z  s         %" %"/ %"8C= %" %" %" %" %" %"T -23 3|3 $D>3
 
3 3 3 3 3 3 3 3rh   c                   B    e Zd Zddedee         f fdZ ej        d          dej	        dej	        fd	            Z
	 	 	 	 	 	 ddej	        deej	                 deej	                 deej                 deej	                 dee         dee         dej	        fdZ xZS )ModernBertEncoderLayerNr   r*  c                    t                                                       || _        |dk    rt          j                    | _        n+t          j        |j        |j        |j	                  | _        t          ||          | _        t          j        |j        |j        |j	                  | _        t          |          | _        d S )Nr   r   )r   r*  )rE   rF   r   r   r3  	attn_normr   rI   rO   rP   r   r  mlp_normr   mlp)rd   r   r*  rf   s      rg   rF   zModernBertEncoderLayer.__init__  s    q==[]]DNN\&*<&/X^XhiiiDN'vIII	V%7V_SYScddd ((rh   Tr   r   rz   c                 R    |                      |                     |                    S r   )rB  rA  rd   r   s     rg   compiled_mlpz#ModernBertEncoderLayer.compiled_mlp  s     xxm44555rh   Frw   r   rx   r   r   r   c           	      .   |                      |                     |          ||||||          }||d         z   }| j        j        r|                     |          n'|                     |                     |                    }	||	z   }|f|dd          z   S )Nrw   r   rx   r   r   r   r   r   )r  r@  r   ra   rE  rB  rA  )
rd   r   rw   r   rx   r   r   r   r:  
mlp_outputs
             rg   r   zModernBertEncoderLayer.forward  s     yyNN=))) 3%!!/ ! 
 
 &Q7 {,8Dm,,,$--6677 	
 &
2,qrr"222rh   r   )NNNNNF)rm   rn   ro   r"   r   r   rF   r   r   r   rE  r   r<  r   rt   ru   s   @rg   r>  r>    s;       	) 	)/ 	)8C= 	) 	) 	) 	) 	) 	) U]4   6%, 65< 6 6 6 ! 6 266:37-1$(,13 3|3 !.3 &el3	3
 u/03 U\*3 SM3 $D>3 
3 3 3 3 3 3 3 3rh   r>  c                        e Zd ZU eed<   dZdZddgZdZdZ	dZ
dej        fdZ	 dd	ee         d
edef fdZd Z fdZ xZS )ModernBertPreTrainedModelr   modelTr   r>  Fr   c                 p   | j         j        ddt          j        dt          ffd}| j         j        | j         j        t          j        d| j         j        z            z  | j         j        | j         j	        dz  d}t          |t                    r ||j        |d                    d S t          |t                    r0 ||j        |d	                     ||j        |d
                    d S t          |t                     r0 ||j        |d	                     ||j        |d
                    d S t          |t$                    r ||j        |d
                    d S t          |t(                    r ||j        |d
                    d S t          |t,          t.          t0          t2          f          r ||j        |d                    d S t          |t          j                  rF|j        j                            d           |j        "|j        j                                          d S d S d S )Nr   r   stdc                     t           j                            | j        d| |z  |z             t	          | t           j                  r-| j        (t           j                            | j                   d S d S d S )Nr5   )r<   rM  ab)r   inittrunc_normal_weight
isinstancer   r   zeros_)r   rM  cutoff_factors     rg   init_weightz<ModernBertPreTrainedModel._init_weights.<locals>.init_weight  s    G!! .3&#% "    &"),, 0;*GNN6;/////0 0**rh   r/   r  )inout	embedding	final_outrZ  rX  rY  r[  g      ?)!r   rN   r   Moduler   rM   mathsqrtrK   rI   rT  r   r   r   r   r   r   r/  ModernBertPredictionHeaddenseModernBertForMaskedLMdecoder#ModernBertForSequenceClassificationModernBertForMultipleChoice ModernBertForTokenClassificationModernBertForQuestionAnswering
classifierr   rS  datafill_r   zero_)rd   r   rW  stdsrV  s       @rg   _init_weightsz'ModernBertPreTrainedModel._init_weights  sd   = M	0	 	0 	0 	0 	0 	0 	0 	0 +/;049S4;C`=`3a3aa60$6	
 
 f233 	)K-tK/@AAAAA.. 	)K	4:...K	4;///// 344 	)KT$Z000K	4;///// 899 	)Kd5k22222 566 	)KU444443+0.	
 
 	) K)4+<=====-- 	)M$$S))){& &&(((((	) 	)&&rh   attn_implementationis_init_checkrz   c                     	 ||                                  rdn|}n# t          t          f$ r Y nw xY wt                                          ||          S )zR
        Checks and dispatches to hhe requested attention implementation.
        Nr&  )rm  rn  )_flash_attn_2_can_dispatchrc   ImportErrorrE   %_check_and_adjust_attn_implementation)rd   rm  rn  rf   s      rg   rr  z?ModernBertPreTrainedModel._check_and_adjust_attn_implementation5  s    	 '.43R3R3T3T. $#(  
 K( 	 	 	D	ww<< 3= = 
 
 	
s    22c                 .   | j         j        du rd S t          | d          rJt          | j                  dk    r2| j         j        rt
                              d           d| j         _        | j        j        dk    r2| j         j        rt
                              d           d| j         _        | j        j        dk    r2| j         j        rt
                              d           d| j         _        | j         j        t                      | j         _        d S d S )	NFhf_device_mapr   zqIf `accelerate` split the model across devices, `torch.compile` will not work. Falling back to non-compiled mode.mpsz|Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. Falling back to non-compiled mode.cpuz|Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. Falling back to non-compiled mode.)
r   ra   hasattrlenrt  loggerwarning_oncer   typer   r   s    rg   _maybe_set_compilez,ModernBertPreTrainedModel._maybe_set_compileL  s)   ;(E11F4)) 	2c$2D.E.E.I.I{, ##9   -2DK);u$${, ##9   -2DK);u$${, ##9   -2DK);(0,?,A,ADK))) 10rh   c                      t                      j        |i |}| j        j        dv r2| j        j        rt                              d           d| j        _        |S )N>   NTzcResizing token embeddings with `torch.compile` is not supported. Falling back to non-compiled mode.F)rE   resize_token_embeddingsr   ra   ry  rz  )rd   argsre   model_embedsrf   s       rg   r~  z1ModernBertPreTrainedModel.resize_token_embeddingsk  sh    6uww6GGG;(L88{, ##y   -2DK)rh   r;  )rm   rn   ro   r"   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_supports_flash_attn_supports_sdpa_supports_flex_attnr   r\  rl  r   r   r<  rr  r|  r~  rt   ru   s   @rg   rJ  rJ    s         &*#/1IJN2)BI 2) 2) 2) 2)j IN
 
#+C=
AE
	
 
 
 
 
 
.B B B>
 
 
 
 
 
 
 
 
rh   rJ  c            !           e Zd Zdef fdZd Zd Ze	 	 	 	 	 	 	 	 	 	 	 	 	 ddee	j
                 dee	j                 dee	j                 d	ee	j
                 d
ee	j                 dee	j                 dee	j                 dee         dee         dee         dee         dee         dee         deee	j        df         ef         fd            Zde	j        dede	j        fdZ xZS )ModernBertModelr   c                 |   t                                                     | _        t                    | _        t          j        fdt          j                  D                       | _	        t          j
        j        j        j                  | _        d| _        |                                  d S )Nc                 0    g | ]}t          |          S rD   )r>  ).0r*  r   s     rg   
<listcomp>z,ModernBertModel.__init__.<locals>.<listcomp>  s$    fff(#FH55fffrh   r   F)rE   rF   r   r   
embeddingsr   
ModuleListrangerK   layersr   rI   rO   rP   
final_normgradient_checkpointing	post_initr   s    `rg   rF   zModernBertModel.__init__z  s       .v66mffffeFLdFeFefff
 
 ,v'9vU[Uefff&+#rh   c                     | j         j        S r   r  r   r   s    rg   get_input_embeddingsz$ModernBertModel.get_input_embeddings  s    --rh   c                     || j         _        d S r   r  )rd   r  s     rg   set_input_embeddingsz$ModernBertModel.set_input_embeddings  s    ).&&&rh   Nr   rw   r   rx   r   r   r   r   
batch_sizeseq_lenr   output_hidden_statesreturn_dictrz   .c           
        	
 ||n| j         j        }||n| j         j        }||n| j         j        }|du |duz  rt	          d          |rdnd}|rdnd}|                                  ||                     ||           	)
'||j        dd         \  	
n|j        dd         \  	
||j        n|j        }|#t          j
        	
f|t          j                  }d}| j         j        dk    rc`|^|\d}|Bt          j                    5  t          ||	          ^}}}}ddd           n# 1 swxY w Y   n\t          ||	          ^}}}}nE|)t          j        
|
                              d          }|                     ||          \  }}|                     ||          }| j        D ]E}|r||fz   } ||||||||          }|d         }|rt)          |          dk    r||d         fz   }F|r||fz   }|                     |          }|r3t-          |	
          }|t/          	
fd|D                       }n^| j         j        dk    rN|L|d                                         dk    r.|                    d          }t/          d |D                       }|st/          d |||fD                       S t3          |||          S )  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nz:You must specify exactly one of input_ids or inputs_embedsrD   r   r   Fr&  T)rv   rw   r   r   )r   )r   r   rG  r   rv   r   r   r   c              3   >   K   | ]}t          |           V  dS )r  N)r   )r  hsr  r   r  s     rg   	<genexpr>z*ModernBertModel.forward.<locals>.<genexpr>  sI       * * +"gZ`ghhh* * * * * *rh   r|   c              3   @   K   | ]}|                     d           V  dS )r   N)	unsqueeze)r  r  s     rg   r  z*ModernBertModel.forward.<locals>.<genexpr>  s,      %R%R"bll1oo%R%R%R%R%R%Rrh   c              3      K   | ]}||V  	d S r   rD   )r  vs     rg   r  z*ModernBertModel.forward.<locals>.<genexpr>  s(      mmq_`_l_l_l_l_lmmrh   )last_hidden_stater   
attentions)r   r   r  use_return_dictrc   r|  %warn_if_padding_and_no_attention_maskr   r   r   onesr<  r0  no_gradr   aranger  _update_attention_maskr  r  rx  r  r   r   r~   r   )rd   r   rw   r   rx   r   r   r   r   r  r  r   r  r  all_hidden_statesall_self_attentionsr   repadr   r   encoder_layerlayer_outputss         `  ``           rg   r   zModernBertModel.forward  sb   B 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]-t";< 	[YZZZ"6@BBD$5?bb4!!! 66y.QQQ'/(&3&9"1"&=#
GG&/obqb&9#
G%.%:!!@T!"ZW(=fTYT^___N;+/BBB:#5*:L (  I`#,^J J JF	7J
Q              
 Ja,^J J JFM7J
Q #$|GFCCCMMaPP262M2M2C 3N 3 3/N/ )=YY![ 	P 	PM# I$58H$H!)M-$7)%%"3  M *!,M  PS%7%7!%;%;&9]1=M<O&O# 	E 1]4D D66 	S2$gZPW  M !,$) * * * * * */* * * % %! K,0CCC!-!"%))++q00)33A66M %%R%R@Q%R%R%R R R 	nmm]4EGZ$[mmmmmm++*
 
 
 	
s   D66D:=D:c                    |ro| j         j        dk    r't                              d           d| j         _        n8| j         j        dk    r(t                              d| j         j         d           t	          || j                  }t          j        |j        d                   	                    d          }t          j
        ||j        z
            }|| j         j        dz  k    	                    d          	                    d                              |j                  }|                    |                                t          j        | j                  j                  }||fS )Nr(  zOutputting attentions is only supported with the 'eager' attention implementation, not with "sdpa". Falling back to `attn_implementation="eager"`.r'  zZOutputting attentions is only supported with the eager attention implementation, not with zT. Consider setting `attn_implementation="eager"`. Setting `output_attentions=False`.r   r   )r   r0  ry  rz  r   r   r   r  r   r  absTrU   r  r   masked_filllogical_notfinfomin)rd   rw   r   global_attention_maskrowsdistancewindow_maskr   s           rg   r  z&ModernBertModel._update_attention_mask  sg    	{/699##V   4;001W<<##: $ @: : :   !;>4: V V |17:;;EEaHH9TDF]++ 499DDQGGQQRSTTWWXfXmnn 	 4??@W@W@Y@Y[`[fgkgq[r[r[vww$&999rh   NNNNNNNNNNNNN)rm   rn   ro   r"   rF   r  r  r   r   r   r   r   r   r<  r   r   r   r   r  rt   ru   s   @rg   r  r  x  s       	/ 	 	 	 	 	 	. . ./ / /  15156:3704*.-1$($(!%,0/3&*A
 A
E,-A
 !.A
 &el3	A

 u/0A
  -A
 %,'A
 U\*A
 SMA
 SMA
 #A
 $D>A
 'tnA
 d^A
 
uU\3&'8	9A
 A
 A
 ^A
F:U\ :VZ :_d_k : : : : : : : :rh   r  c                   H     e Zd Zdef fdZdej        dej        fdZ xZS )r_  r   c                 .   t                                                       || _        t          j        |j        |j        |j                  | _        t          |j	                 | _
        t          j        |j        |j        |j                  | _        d S )Nr   )rE   rF   r   r   r   rI   r\   r`  r   r]   r   r   rO   rP   r   r   s     rg   rF   z!ModernBertPredictionHead.__init__0  sq    Yv163EvG]^^
&67L!3vO_```			rh   r   rz   c                 x    |                      |                     |                     |                              S r   )r   r   r`  rD  s     rg   r   z ModernBertPredictionHead.forward7  s,    yy$**]";";<<===rh   )	rm   rn   ro   r"   rF   r   r   r   rt   ru   s   @rg   r_  r_  /  sr        a/ a a a a a a>U\ >el > > > > > > > >rh   r_  zd
    The ModernBert Model with a decoder head on top that is used for masked language modeling.
    )custom_introc            "       (    e Zd ZdgZdef fdZd Zdej        fdZ	 e
j        d          d	e
j        d
e
j        fd            Ze	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddee
j                 dee
j                 dee
j                 dee
j                 dee
j                 dee
j                 dee
j                 dee
j                 dee         dee         dee         dee         dee         dee         d
eee
j                 ef         fd            Z xZS )ra  zdecoder.weightr   c                 j   t                                          |           || _        t          |          | _        t          |          | _        t          j        |j	        |j
        |j                  | _        | j        j        | _        | j        j        | _        |                                  d S )Nr   )rE   rF   r   r  rK  r_  headr   r   rI   rG   rZ   rb  r_   r`   r  r   s     rg   rF   zModernBertForMaskedLM.__init__C  s       $V,,
,V44	y!3V5FVM`aaa!%!>(,(L% 	rh   c                     | j         S r   rb  r   s    rg   get_output_embeddingsz+ModernBertForMaskedLM.get_output_embeddingsP  s
    |rh   new_embeddingsc                     || _         d S r   r  )rd   r  s     rg   set_output_embeddingsz+ModernBertForMaskedLM.set_output_embeddingsS  s    %rh   Tr   rl   rz   c                 R    |                      |                     |                    S r   )rb  r  )rd   rl   s     rg   compiled_headz#ModernBertForMaskedLM.compiled_headV  s     ||DIIf--...rh   Nr   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  c                 (   ||n| j         j        }|                                  | j         j        dk    r|||	|
)|'||j        dd         \  }
}n|j        dd         \  }
}||j        n|j        }|#t          j        |
|f|t          j                  }|Ft          j	                    5  t          ||||          \  }}}}	}}ddd           n# 1 swxY w Y   nt          ||||          \  }}}}	}}|                     ||||||||	|
||||          }|d         }| j        rS|Q|                    d          }|                    |j        d         d          }|| j        k    }||         }||         }| j         j        r|                     |          n'|                     |                     |                    }d}| | j        ||fd	| j         j        i|}| j         j        dk    r| j         j        s|t-                      nt          j	                    5  t/          |||
|
          }ddd           n# 1 swxY w Y   t1          |dd          g }|j        D ]f}|                                dk    r&|j        d         dk    r|                    d          }|                    t/          |||
|
                     gt;          |          |_        |s|f}||f|z   n|S t=          |||j        |j                  S )r  Nr&  r   r   )rv   rw   rx   ry   r   rw   r   rx   r   r   r   r   r  r  r   r  r  r   r|   rG   r  r   r   r   losslogitsr   r  ) r   r  r|  r0  r   r   r   r  r<  r  r   rK  r_   r   r`   ra   r  rb  r  loss_functionrG   rb   r   r   getattrr   r~   squeezeappendr   r   r  )rd   r   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  re   r   outputsr  mask_tokensr  r  padded_hidden_statesr  rl   s                            rg   r   zModernBertForMaskedLM.forwardZ  sX   F &1%<kk$+B]!!!;+/BBB:#5*:L%'/$0.;.A"1".E+
GG.7obqb.A+
G-6-B))H\!)%*ZW0Ef\a\f%g%g%gN (  [r#,^Zfou\ \ \X	7J
LRX              
 \s,^Zfou\ \ \XM7J
LRX **) 3%'!!!/!5#  
 
 $AJ! 	)f&8[[__F 1 6 6v|A K K !D$AAK 1+ >K(F {,<D0111dii(9::;; 	 %4%ffbbAWb[abbD;+/BBB"&+"Dk\a\i\k\k r r/vwV`ipqqqr r r r r r r r r r r r r r r w66B')$!/  Bvvxx1}}!)9)9ZZ]](//.b'Q[dklll    )..B(C(C% 	FYF)-)9TGf$$vE!/)	
 
 
 	
s$   0CCC(IIINNNNNNNNNNNNNN)rm   rn   ro   _tied_weights_keysr"   rF   r  r   r   r  r   r   r   r  r   r   r   r   r<  r   r   r   r   rt   ru   s   @rg   ra  ra  ;  s        ++/        &BI & & & & U]4   /EL /U\ / / / ! /  15156:/304)-*.-1$($(!%,0/3&*x
 x
E,-x
 !.x
 &el3	x

 u|,x
  -x
 &x
 %,'x
 U\*x
 SMx
 SMx
 #x
 $D>x
 'tnx
 d^x
" 
uU\"N2	3#x
 x
 x
 ^x
 x
 x
 x
 x
rh   ra  z`
    The ModernBert Model with a sequence classification head on top that performs pooling.
    c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                 deej	                 deej	                 deej	                 deej	                 d	eej	                 d
eej	                 deej	                 dee
         dee
         dee
         dee         dee         dee         deeej	                 ef         fd            Z xZS )rc  r   c                    t                                          |           |j        | _        || _        t	          |          | _        t          |          | _        t          j	        
                    |j                  | _        t          j        |j        |j                  | _        |                                  d S r   )rE   rF   
num_labelsr   r  rK  r_  r  r   r   r   r[   r   r   rI   rg  r  r   s     rg   rF   z,ModernBertForSequenceClassification.__init__  s        +$V,,
,V44	H$$V%>??	)F$68IJJ 	rh   Nr   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  rz   c                    ||n| j         j        }|                                  ||                     ||           |
)|'||j        dd         \  }
}n|j        dd         \  }
}||j        n|j        }|#t          j        |
|f|t          j                  }| 	                    ||||||||	|
||||          }|d         }| j         j
        dk    r|dddf         }nT| j         j
        dk    rD||                    d          z                      d	
          |                    d	d          z  }|                     |          }|                     |          }|                     |          }d}|Z| j         j        f| j        d	k    rd| j         _        nN| j        d	k    r7|j        t          j        k    s|j        t          j        k    rd| j         _        nd| j         _        | j         j        dk    rWt+                      }| j        d	k    r1 ||                                |                                          }n |||          }n| j         j        dk    rGt/                      } ||                    d| j                  |                    d                    }n*| j         j        dk    rt3                      } |||          }|s|f}||f|z   n|S t5          |||j        |j                  S )aB  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nr   r   r  r   r8   r<   r|   r   r   Tr~   keepdim
regressionsingle_label_classificationmulti_label_classificationr  )r   r  r|  r  r   r   r   r  r<  rK  r:   r  r   r  r   rg  problem_typer  r   longr   r
   r  r	   r   r   r   r   r  )rd   r   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  re   r   r  r  pooled_outputr  r  loss_fctrl   s                           rg   r   z+ModernBertForSequenceClassification.forward  s]   N &1%<kk$+B]!!! 66y.QQQ'/(&3&9"1"&=#
GG&/obqb&9#
G%.%:!!@T!"ZW(=fTYT^___N**) 3%'!!!/!5#  
 
 $AJ;)U22 1!!!Q$ 7[+v55!2^5M5Mb5Q5Q!Q V V[\ V ] ]`n`r`rt as a a ! 		"344		-00//{'/?a''/;DK,,_q((flej.H.HFL\a\eLeLe/LDK,,/KDK,{'<77"99?a''#8FNN$4$4fnn6F6FGGDD#8FF33DD)-JJJ+--xB @ @&++b//RR)-III,..x// 	FYF)-)9TGf$$vE'!/)	
 
 
 	
rh   r  )rm   rn   ro   r"   rF   r   r   r   r   r   r   r<  r   r   r   r   rt   ru   s   @rg   rc  rc    s       /        15156:/304)-*.-1$($(!%,0/3&*r
 r
E,-r
 !.r
 &el3	r

 u|,r
  -r
 &r
 %,'r
 U\*r
 SMr
 SMr
 #r
 $D>r
 'tnr
 d^r
" 
uU\"$<<	=#r
 r
 r
 ^r
 r
 r
 r
 r
rh   rc  zv
    The ModernBert Model with a token classification head on top, e.g. for Named Entity Recognition (NER) tasks.
    c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                 deej	                 deej	                 deej	                 deej	                 d	eej	                 d
eej	                 deej	                 dee
         dee
         dee
         dee         dee         dee         deeej	                 ef         fd            Z xZS )re  r   c                 t   t                                          |           |j        | _        t          |          | _        t          |          | _        t          j        	                    |j
                  | _        t          j        |j        |j                  | _        |                                  d S r   rE   rF   r  r  rK  r_  r  r   r   r   r[   r   r   rI   rg  r  r   s     rg   rF   z)ModernBertForTokenClassification.__init__e  s        +$V,,
,V44	H$$V%>??	)F$68IJJ 	rh   Nr   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  rz   c                    ||n| j         j        }|                                  |                     ||||||||	|
||||          }|d         }|                     |          }|                     |          }|                     |          }d}|Ft                      } ||                    d| j	                  |                    d                    }|s|f|dd         z   }||f|z   n|S t          |||j        |j                  S )a  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nr  r   r|   r   r  )r   r  r|  rK  r  r   rg  r	   r   r  r   r   r  )rd   r   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  r  r  r  r  r  rl   s                        rg   r   z(ModernBertForTokenClassification.forwardq  sK   H &1%<kk$+B]!!!**) 3%'!!!/!5#  
 
 $AJ II&788 II&788!233'))H8FKKDO<<fkk"ooNND 	FY,F)-)9TGf$$vE$!/)	
 
 
 	
rh   r  )rm   rn   ro   r"   rF   r   r   r   r   r   r   r<  r   r   r   r   rt   ru   s   @rg   re  re  _  s       
/ 
 
 
 
 
 
  15156:/304)-*.-1$($(!%,0/3&*I
 I
E,-I
 !.I
 &el3	I

 u|,I
  -I
 &I
 %,'I
 U\*I
 SMI
 SMI
 #I
 $D>I
 'tnI
 d^I
  
uU\"$99	:!I
 I
 I
 ^I
 I
 I
 I
 I
rh   re  c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deej                 d	eej                 d
eej                 deej                 dee	         dee	         dee	         dee
         dee
         dee
         deeej                 ef         fd            Z xZS )rf  r   c                 t   t                                          |           |j        | _        t          |          | _        t          |          | _        t          j        	                    |j
                  | _        t          j        |j        |j                  | _        |                                  d S r   r  r   s     rg   rF   z'ModernBertForQuestionAnswering.__init__  s        +$V,,
,V44	H$$V%>??	)F$68IJJrh   Nr   rw   r   rx   start_positionsend_positionsr   r   r   r  r  r   r  r  rz   c                    ||n| j         j        }|                                  |                     |||||||	|
||||          }|d         }|                     |          }|                     |          }|                     |          }|                    dd          \  }}|                    d          	                                }|                    d          	                                }d}|| | j
        ||||fi |}|s||f|dd         z   }||f|z   n|S t          ||||j        |j                  S )r  N)rw   r   rx   r   r   r   r  r  r   r  r  r   r   r|   r   )r  start_logits
end_logitsr   r  )r   r  r|  rK  r  r   rg  splitr  r   r  r   r   r  )rd   r   rw   r   rx   r  r  r   r   r   r  r  r   r  r  re   r  r  r  r  r  r  rl   s                          rg   r   z&ModernBertForQuestionAnswering.forward  s   F &1%<kk$+B]!!!**) 3%!!!/!5#  
 
 $AJ II&788 II&788!233#)<<r<#:#: j#++B//::<<''++6688
&=+D%4%lJQ^iibhiiD 	F"J/'!""+=F)-)9TGf$$vE+%!!/)
 
 
 	
rh   r  )rm   rn   ro   r"   rF   r   r   r   r   r   r<  r   r   r   r   rt   ru   s   @rg   rf  rf    s       	/ 	 	 	 	 	 	  266:/32604*.-1$($(!%,0/3&*K
 K
EL)K
 !.K
 &el3	K

 u|,K
 "%,/K
  -K
 %,'K
 U\*K
 SMK
 SMK
 #K
 $D>K
 'tnK
 d^K
" 
uU\"$@@	A#K
 K
 K
 ^K
 K
 K
 K
 K
rh   rf  z
    The ModernBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
    c            "           e Zd Zdef fdZe	 	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                 deej	                 deej	                 deej	                 deej	                 d	eej	                 d
eej	                 deej	                 dee
         dee
         dee
         dee         dee         dee         deeej	                 ef         fd            Z xZS )rd  r   c                 `   t                                          |           || _        t          |          | _        t          |          | _        t          j        	                    |j
                  | _        t          j        |j        d          | _        |                                  d S )Nr   )rE   rF   r   r  rK  r_  r  r   r   r   r[   r   r   rI   rg  r  r   s     rg   rF   z$ModernBertForMultipleChoice.__init__   s       $V,,
,V44	H$$V%>??	)F$6:: 	rh   Nr   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  rz   c                    ||n| j         j        }||j        d         n|j        d         }|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|=|                    d|                    d          |                    d                    nd}|                                  |                     ||||||||	|
||||          }|d         }| j         j        dk    rt          j	        |j        d         |j
                  }|/|                    d	                              |j
                  }n&t          j        dt          j        |j
        
          }|||f         }nV| j         j        dk    rF|                    dd          }||                    d          z                      d	          |z  }|                     |          }|                     |          }|                     |          }|                    d|          }d}|t)          j                    } |||          }|s|f|dd         z   }||f|z   n|S t-          |||j        |j                  S )a  
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
            far-away tokens in the local attention layers when not using Flash Attention.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors.
        indices (`torch.Tensor` of shape `(total_unpadded_tokens,)`, *optional*):
            Indices of the non-padding tokens in the input sequence. Used for unpadding the output.
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        Nr   r|   r  r   r8   r  r   r   r<   Tr  r  )r   r  r   r   sizer|  rK  r:   r   r  r   argmaxr  tensorr  r   r  r  r   rg  r   r	   r   r   r  )rd   r   rw   r   rx   r   ry   r   r   r   r  r  r   r  r  re   num_choicesr  r  	indices_0cls_masknum_non_pad_tokensr  r  reshaped_logitsr  r  rl   s                               rg   r   z#ModernBertForMultipleChoice.forward,  s!   L &1%<kk$+B],5,Aioa((}GZ[\G]>G>SINN2y~~b'9'9:::Y]	M[Mg,,R1D1DR1H1HIIImqGSG_|((\->->r-B-BCCCei ( r=#5#5b#9#9=;M;Mb;Q;QRRR 	 	!!!**) 3%'!!!/!5#  
 
 $AJ ;)U22%6%<Q%?HYH`aaaI))00R088;;<M<TUU !<DUD\]]] 1)X2E F [+v55!/!3!34!3!H!H!2^5M5Mb5Q5Q!Q V V[\ V ] ]`r r		"344		-00// ++b+66*,,H8OV44D 	F%''!""+5F)-)9TGf$$vE("!/)	
 
 
 	
rh   r  )rm   rn   ro   r"   rF   r   r   r   r   r   r   r<  r   r   r   r   rt   ru   s   @rg   rd  rd    s       
/ 
 
 
 
 
 
  15156:/304)-*.-1$($(!%,0/3&*i
 i
E,-i
 !.i
 &el3	i

 u|,i
  -i
 &i
 %,'i
 U\*i
 SMi
 SMi
 #i
 $D>i
 'tni
 d^i
" 
uU\"$==	>#i
 i
 i
 ^i
 i
 i
 i
 i
rh   rd  )r"   r  rJ  ra  rc  re  rf  rd  r   r;  )Wr1  r]  
contextlibr   typingr   r   r   r   torch.nn.functionalr   r   r#  torch.nnr   r	   r
   activationsr   configuration_utilsr   modeling_attn_mask_utilsr   modeling_layersr   modeling_outputsr   r   r   r   r   r   modeling_utilsr   utilsr   r   r   utils.import_utilsr   gemma.modeling_gemmar   r   flash_attn.flash_attn_interfacer   flash_attn.layers.rotaryr   flash_attn.ops.triton.rotaryr    object
get_loggerrm   ry  r"   r   r   r   r   r   autogradFunctionr   r   r   r\  r   r   r   r   r<  r  r  r   r   r%  r9  r   r>  rJ  r  r_  ra  rc  re  rf  rd  __all__rD   rh   rg   <module>r     s      " " " " " " + + + + + + + + + +                 A A A A A A A A A A ! ! ! ! ! ! 3 3 3 3 3 3 B B B B B B 9 9 9 9 9 9                . - - - - - G G G G G G G G G G 5 5 5 5 5 5 M M M M M M M M  PPPPPP8888889999999O 
	H	%	%B B B B B' B B BP ,0%)	&m &mL&mL&m 5<(&m U\"	&m
 5<u|S(5<:PRZ[`[gRhhi&m &m &m &mRL\  	
 \   >46 46 46 46 46%.1 46 46 46v *. $L L &	L
 L L L L42Q 2Q 2Q 2Q 2Q 2Q 2Q 2Qj    29   <: : : : :BI : : :(	 	 	 	 	 4 	 	 	 )." "!"	" L" 	"
 5+," 38_" 	" 
"  ~" 5u|+,eEL.AAB" " " "\ !&(! (!!(!	(! 2(! 	(!
 (! 38_(! 	(! 
(! +(! 5<(! (! (! (!V ! 	  L  	 
 5+,  38_  	  
  5<       H 1$"! ! L3 L3 L3 L3 L3") L3 L3 L3^+3 +3 +3 +3 +37 +3 +3 +3\ } } } } } } } }@ s: s: s: s: s:/ s: s: s:l	> 	> 	> 	> 	>ry 	> 	> 	>   
S
 S
 S
 S
 S
5 S
 S
 
S
l   
A
 A
 A
 A
 A
*C A
 A
 
A
H   
W
 W
 W
 W
 W
'@ W
 W
 
W
t X
 X
 X
 X
 X
%> X
 X
 X
v   
w
 w
 w
 w
 w
"; w
 w
 
w
t	 	 	rh   