
     `iOl                        d Z ddlmZ ddlZddlmc mZ ddlmZ ddlm	Z	 ddl
mZmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZ ddlmZ ddlmZmZ ddlmZ ddlm Z m!Z! ddl"m#Z# ddl$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z-m.Z.  ej/        e0          Z1 G d de#          Z2 G d de-          Z3 G d de          Z4 G d dej5                  Z6 G d de%          Z7 G d de.          Z8 G d  d!e&e          Z9 G d" d#e,          Z: G d$ d%e+          Z; G d& d'e'          Z< G d( d)e)          Z= G d* d+e*          Z> G d, d-e(          Z?g d.Z@dS )/zPyTorch MiniMax model.    )OptionalN)nn   )ACT2FN)CacheDynamicCache)layer_type_validation)create_causal_mask!create_sliding_window_causal_mask)FlashAttentionKwargs)GradientCheckpointingLayer)MoeModelOutputWithPast)Unpack)TransformersKwargslogging)deprecate_kwarg)OutputRecordercheck_model_inputs   )MixtralConfig)
MixtralAttentionMixtralDecoderLayerMixtralForCausalLMMixtralForQuestionAnswering MixtralForSequenceClassificationMixtralForTokenClassificationMixtralModelMixtralPreTrainedModelMixtralRMSNormMixtralSparseMoeBlockc                   4     e Zd ZdZ	 	 	 	 	 	 	 	 d fd	Z xZS )MiniMaxConfiga  
    This is the configuration class to store the configuration of a [`MiniMaxModel`]. It is used to instantiate an
    MiniMax model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MiniMax.

    [MiniMaxAI/MiniMax-Text-01-hf](https://huggingface.co/MiniMaxAI/MiniMax-Text-01-hf)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the MiniMax model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`MiniMaxModel`]
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
        head_dim (`int`, *optional*, defaults to `hidden_size // num_attention_heads`):
            The attention head dimension.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. MiniMax's sliding window attention
            allows sequence of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 1000000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*):
            Sliding window attention window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per-token, can be also interpreted as the `top-k` routing
            parameter
        num_local_experts (`int`, *optional*, defaults to 8):
            Number of experts per Sparse MLP layer.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss. See [here]() for more details
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        router_jitter_noise (`float`, *optional*, defaults to 0.0):
            Amount of noise to add to the router.
        layer_types (`list`, *optional*):
            Attention pattern for each layer.
        block_size (`int`, *optional*, defaults to 256):
            The length of each attention block, determining how queries, keys, and values
            are grouped and processed for intra- and inter-block attention.
        full_attn_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after normal attention.
        full_attn_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after normal attention.
        linear_attn_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after lightning attention.
        linear_attn_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after lightning attention.
        mlp_alpha_factor (`float`, *optional*, defaults to 1):
            Weight for residual value in residual connection after MLP.
        mlp_beta_factor (`float`, *optional*, defaults to 1):
            Weight for hidden state value in residual connection after MLP.

    ```python
    >>> from transformers import MiniMaxModel, MiniMaxConfig

    >>> # Initializing a MiniMax style configuration
    >>> configuration = MiniMaxConfig()

    >>> # Initializing a model from the MiniMax style configuration
    >>> model = MiniMaxModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```N      c	                 2    t                      j        di |	 || _        || _        || _        || _        || _        || _        || _        || _	        | j        #d t          | j                  D             | _        t          | j        | j                   d S )Nc                 @    g | ]}t          |d z   dz            rdndS )r$   r   full_attentionlinear_attention)bool).0is     /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/minimax/modular_minimax.py
<listcomp>z*MiniMaxConfig.__init__.<locals>.<listcomp>   sA          RSD!a%1$5$5M  ;M          )super__init__layer_types
block_sizefull_attn_alpha_factorfull_attn_beta_factorlinear_attn_alpha_factorlinear_attn_beta_factormlp_alpha_factormlp_beta_factorrangenum_hidden_layersr	   )selfr2   r3   r4   r5   r6   r7   r8   r9   super_kwargs	__class__s             r,   r1   zMiniMaxConfig.__init__   s     	((<(((&$&<#%:"(@%'>$ 0.#   W\]a]sWtWt     D 	d.0FGGGGGr.   )Nr#   r$   r$   r$   r$   r$   r$   )__name__
__module____qualname____doc__r1   __classcell__r>   s   @r,   r"   r"   5   sn        c cN  !" !H H H H H H H H H Hr.   r"   c                       e Zd ZdS )MiniMaxRMSNormNr?   r@   rA   r/   r.   r,   rF   rF              Dr.   rF   c                   ~     e Zd Z fdZd ZdefdZ fdZdef fdZd Z	defd	Z
d
ej        fdZdefdZ xZS )MiniMaxCachec                 V    t                                                       g | _        d S N)r0   r1   linear_cacher<   r>   s    r,   r1   zMiniMaxCache.__init__   s'    02r.   c                     t          t          | j                  |dz             D ]}| j                            g            || j        |<   d S )Nr$   )r:   lenrM   append)r<   	layer_idxrM   _s       r,   set_linear_cachezMiniMaxCache.set_linear_cache   sW    s4,--y1}== 	) 	)A$$R(((('3)$$$r.   rR   c                 F    |t          |           k     r| j        |         S d S rL   )rP   rM   r<   rR   s     r,   get_linear_cachezMiniMaxCache.get_linear_cache   s&    s4yy  $Y//tr.   c                     t          t                                                      t          | j                            S rL   )maxr0   __len__rP   rM   rN   s    r,   rZ   zMiniMaxCache.__len__   s,    577??$$c$*;&<&<===r.   c                     |t          | j                  k     r| j        |         g k    r| j        |         fS t                                          |          S rL   )rP   rM   r0   __getitem__)r<   rR   r>   s     r,   r\   zMiniMaxCache.__getitem__   sU    s4,----$2CI2NRT2T2T%i022ww""9---r.   c              #   \   K   t          t          |                     D ]}| |         V  d S rL   )r:   rP   rV   s     r,   __iter__zMiniMaxCache.__iter__   s@      s4yy)) 	" 	"Iy/!!!!	" 	"r.   repeatsc                     t          t          |                     D ]^}| j        |         g k    r+| j        |                             |d          | j        |<   >| j        |                             |           _d S )Nr   dim)r:   rP   rM   repeat_interleavelayersbatch_repeat_interleave)r<   r_   rR   s      r,   re   z$MiniMaxCache.batch_repeat_interleave   s    s4yy)) 	H 	HI +r11/3/@/K/]/]^ekl/]/m/m!),,I&>>wGGGG		H 	Hr.   indicesc                     t          t          |                     D ]Q}| j        |         g k    r| j        |         |df         | j        |<   1| j        |                             |           Rd S )N.)r:   rP   rM   rd   batch_select_indices)r<   rf   rR   s      r,   rh   z!MiniMaxCache.batch_select_indices   s    s4yy)) 	E 	EI +r11/3/@/KGUXL/Y!),,I&;;GDDDD		E 	Er.   
max_lengthc                      t          d          )Nz*MiniMaxCache doesnot support `crop` method)RuntimeError)r<   ri   s     r,   cropzMiniMaxCache.crop   s    GHHHr.   )r?   r@   rA   r1   rT   intrW   rZ   r\   r^   re   torchTensorrh   rl   rC   rD   s   @r,   rJ   rJ      s       3 3 3 3 34 4 4#    
> > > > >.S . . . . . .
" " "Hs H H H HEEL E E E EIs I I I I I I I Ir.   rJ   c                   L    e Zd Zdedef fdZd Zd Z eddd	          	 	 dde	j
        dee	j
        e	j
        f         dee	j
                 dee         dee	j                 dee         dee	j
        ee	j
                 eee	j
                          f         fd            Z xZS )MiniMaxLightningAttentionconfigrR   c                 |   t                                                       || _        t          |dd           p|j        |j        z  | _        |j        | _        |j        | _        |j        | _        t          |j
                 | _        t          | j        | j        z            | _        t          j        |j        | j        | j        z  dz  d          | _        t          j        | j        | j        z  |j        d          | _        t          j        |j        | j        | j        z  d          | _        |                                 }|                     |          \  }}}|                     d|           |                     d|           |                     d|           |                     d|           d S )	Nhead_dimr   F)bias
slope_ratequery_decay	key_decaydiagonal_decay)r0   r1   rR   getattrhidden_sizenum_attention_headsrt   r;   r3   r   
hidden_actact_fnrF   normr   Linearqkv_projout_projoutput_gateget_slope_ratedecay_factorsregister_buffer)r<   rr   rR   rv   rw   rx   ry   r>   s          r,   r1   z"MiniMaxLightningAttention.__init__   s   "
D99mV=OSYSm=m#)#= !'!9 +V./"4=43K#KLL		&"4d6NQUQ^6^ab6binooo	$":T]"JFL^ejkkk9V%79QTXTa9ahmnnn((**
151C1CJ1O1O.Y\:666]K888[)444-~>>>>>r.   c                     ddd| j         z  z  z  }t          j        | j                   dz   }d| j        | j        dz
  dz   z  z
  dz   }||z  }||z  }|d d d d f         }|S )Nr$   r      gh㈵>)r|   rn   arangerR   r;   )r<   baseexponentfactorrates        r,   r   z(MiniMaxLightningAttention.get_slope_rate  s~    A!d6678< 899A=T^t'='AD'HIIDPX~f}AAAtTM"r.   c                    t          j        | j                  dz   }t          j        | |d d d f         z            }t          j        | | j        |d d d f         z
  z            }|d d d f         |d d d f         z
  }|d d d d d d f         }||z  }t          j        |dk    | t          d                    }t          j        |          }|||fS )Nr$   r   z-inf)rn   r   r3   expwherefloat)r<   rv   block_size_rangerw   rx   ry   s         r,   r   z'MiniMaxLightningAttention.decay_factors  s     <881<i.>qqq$w.G GHHIzkT_?OPQPQPQSWPW?X-XYZZ	)!!!T'25EdAAAg5NN'dAAAqqq(89#n4^q%8>/5QW==YY>22I~55r.   past_key_valuepast_key_values4.58new_nameversionNhidden_statesposition_embeddingsattention_maskcache_positionkwargsreturnc                 	   |j         \  }}}	|| j        z   dz
  | j        z  }
|                     |                     |                    }|                    ||| j        d| j        z            }t          j        || j        d          \  }}}|	                    dd          }|	                    dd          }|	                    dd          }d }||
                    | j                  }|t          j        || j        | j        | j                                      |          }|]|                    t          j                  }|                    |                    d                              d           d          }g }t#          |
          D ]a}|| j        z  }t%          || j        z   |          }||z
  }|d d d d ||f         }|d d d d ||f         }|d d d d ||f         }| j        d d d |f         }| j        d d | d f         }| j        d d d d d |d |f         }t          j        | j         |z            }t          j        ||	                    dd                    }t          j        ||z  |          }t          j        ||z  |          }||z   }|                    |           t          j        ||z  	                    dd          |          } ||z  | z   }cnt          j        | j                   }!g }t#          |          D ]}|d d d d ||dz   f         }|d d d d ||dz   f         }|d d d d ||dz   f         }t          j        |	                    dd          |          }"|!|z  |"z   }t          j        ||          }|                    |           t          j        |d          }|	                    dd          }|                    ||| j        | j        z            }|                     |          }t9          j        |                     |                    |z  }|                     |          }||                     | j        |           ||fS )	Nr$   r   ra   r   )dtyper   )!shaper3   r~   r   reshaper|   rt   rn   split	transposerW   rR   zerostor)   masked_fill	unsqueezer:   minrw   rx   ry   r   rv   matmulrQ   catr   Fsigmoidr   r   rT   )#r<   r   r   r   r   r   r   
batch_sizeseq_lenr{   
num_blocks
qkv_statesquery_states
key_statesvalue_statesattn_weights_interattn_outputr+   	start_idxend_idxcurrent_block_sizecurrent_query_statescurrent_key_statescurrent_value_statescurrent_query_decaycurrent_key_decaycurrent_diagonal_decayblock_decayattn_weights_intraattn_output_intraattn_output_intercurrent_attn_outputnext_attn_weights_interratiocurrent_attn_weights_inters#                                      r,   forwardz!MiniMaxLightningAttention.forward  s5    ,9+>(
G[/!3G
[[}!=!=>>
''
GT=UWX[_[hWhii
16Z\]1^1^1^.j,#--a33))!Q//
#--a33 "&!0!A!A$.!Q!Q%!&Z9QSWS`bfbo!p!p!s!s" "
 )!/!2!2!2!D!D+779Q9QRS9T9T9^9^_a9b9b8bdeffK:&& ` `/	i$/97CC%,y%8"'3AAAqqq)G:K4K'L$%/111i6G0G%H"'3AAAqqq)G:K4K'L$&*&6qqq:M;M:M7M&N#$(N1117I6I6J6J3J$K!)-)<QQQCVDVCVXkYkXk=k)l&#i(8;M(MNN &+\2FHZHdHdegikHlHl%m%m"$)L1CF\1\^r$s$s! %*L1EH[1[]o$p$p! '8:K&K#""#6777 +0,'*;;FFr2NNPd+ +' &8+%EH_%_"";`@ It.//EK7^^ 	8 	8'3AAAqqq!a!e)O'D$%/111a!a%i%@"'3AAAqqq!a!e)O'D$-2\:L:V:VWY[]:^:^`t-u-u*%*-?%?B\%\"&+l3GI[&\&\#""#67777 i444 "++Aq11!))*gt?WZ^Zg?ghhii,,i 0 0 ? ?@@;NmmK00 &,,T^=OPPP...r.   )NN)r?   r@   rA   r"   rm   r1   r   r   r   rn   ro   tupler   r   
LongTensorr   r   r   rC   rD   s   @r,   rq   rq      sI       ?} ? ? ? ? ? ? ?,	 	 	6 6 6 _%0A6RRR ,059`/ `/|`/ #5<#=>`/ !.	`/
 "%`/ !!12`/ -.`/ 
u|Xel3XeEL>Q5RR	S`/ `/ `/ SR`/ `/ `/ `/ `/r.   rq   c                       e Zd ZdS )MiniMaxAttentionNrG   r/   r.   r,   r   r   ~  rH   r.   r   c                       e Zd ZdS )MiniMaxSparseMoeBlockNrG   r/   r.   r,   r   r     rH   r.   r   c                       e Zd Zdedef fdZ eddd          	 	 	 	 	 	 	 dd
ej        de	ej        ej        f         de
ej                 de
ej                 de
e         de
e         de
e         de
e         de
ej                 dee         de	ej        e
e	ej        ej        f                  f         fd            Z xZS )MiniMaxDecoderLayerrr   rR   c                 |   t                                          ||           || _        |j        |         | _        |j        | _        |j        | _        | j        dk    r/t          ||          | _        |j	        | _
        |j        | _        d S t          ||          | _        |j        | _
        |j        | _        d S )Nr(   )r0   r1   rR   r2   
layer_typer8   r9   rq   	self_attnr6   attn_alpha_factorr7   attn_beta_factorr   r4   r5   )r<   rr   rR   r>   s      r,   r1   zMiniMaxDecoderLayer.__init__  s    +++" ,Y7 & 7%5?0006vyIIDN%+%DD"$*$BD!!!-fi@@DN%+%BD"$*$@D!!!r.   r   r   r   r   NFr   r   r   position_idsoutput_attentionsoutput_router_logits	use_cacher   r   r   c
                    |                      |          }|} | j        d||||||||	d|
\  }}|| j        z  || j        z  z   }|                     |          }|}|                     |          \  }}|| j        z  || j        z  z   }|S )a  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            attention_mask (`torch.Tensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_router_logits (`bool`, *optional*):
                Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
                should not be returned during inference.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        )r   r   r   r   r   r   r   r   r/   )input_layernormr   r   r   post_attention_layernormblock_sparse_moer8   r9   )r<   r   r   r   r   r   r   r   r   r   r   residualrS   s                r,   r   zMiniMaxDecoderLayer.forward  s    N ,,];;  *4> 

' 3)%+/)

 

 

 

q !4#99MDLa<aa 55mDD 00??q 4#88=4K_;__r.   )NNNFFFN)r?   r@   rA   r"   rm   r1   r   rn   ro   r   r   r   r   r)   r   r   FloatTensorr   rC   rD   s   @r,   r   r     so       A} A A A A A A A" _%0A6RRR
 2637+/,1/4$)59= =|= #5<#=>= !.	=
 u/0= "%= $D>= 'tn= D>= !!12= -.= 
u (51BEDU1U+V"WW	X= = = SR= = = = =r.   r   c                   8    e Zd ZdZ eed          eeegdZ	dS )MiniMaxPreTrainedModelFr$   )index)router_logitsr   
attentionsN)
r?   r@   rA   _can_compile_fullgraphr   r   r   r   rq   _can_record_outputsr/   r.   r,   r   r     sB        "'(=QGGG,')BC r.   r   c                       e Zd Ze	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 dee         deej	                 dee
         dee
         d	eej                 d
ee         defd            ZdS )MiniMaxModelN	input_idsr   r   r   inputs_embedsr   r   r   r   r   c	                    |d u |d uz  rt          d          |r|t                      }n7|r5t          |t                    s t          dt          |           d          ||                     |          }|B||                                nd}
t          j        |
|
|j        d         z   |j	                  }||
                    d          }| j        j        t          nt          } || j        |||||          }|}|                     ||          }| j        D ]"}|j        dk    r|}n|} ||f||||||d	|	}#|                     |          }t'          ||
          S )Nz:You must specify exactly one of input_ids or inputs_embedszSMiniMax uses cache of its own and is not compatible with `past_key_values` of type .r   r$   )device)rr   input_embedsr   r   r   r   r'   )r   r   r   r   r   r   )last_hidden_stater   )
ValueErrorrJ   
isinstancetypeembed_tokensget_seq_lengthrn   r   r   r   r   rr   sliding_windowr
   r   
rotary_embrd   r   r   r   )r<   r   r   r   r   r   r   r   r   r   past_seen_tokensmask_functioncausal_maskr   r   decoder_layerinput_attention_masks                    r,   r   zMiniMaxModel.forward  s    -t";< 	[YZZZ 	0*nnOO 	z/<HH 	~fjkzf{f{~~~     --i88M!CRC^==???de"\ "2]5H5K"KTaTh  N )33A66L.2k.H.P**Vw#m;&))+%
 
 
 & #oom\JJ![ 	 	M'+;;;'2$$ (6$)M	$73) /#-	 	 	 	MM 		-00%++
 
 
 	
r.   )NNNNNNNN)r?   r@   rA   r   r   rn   r   ro   rJ   r   r)   r   r   r   r   r/   r.   r,   r   r     s         1515372659$(,059G
 G
E,-G
 !.G
 u/0	G

 ",/G
   12G
 D>G
 $D>G
 !!12G
 +,G
 
 G
 G
 G
 G
 G
 G
r.   r   c                        e Zd Z fdZ xZS )MiniMaxForCausalLMc                 6     t                      j        di |S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MiniMaxForCausalLM

        >>> model = MiniMaxForCausalLM.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("MiniMaxAI/MiniMax-Text-01-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```r/   )r0   r   )r<   r=   r>   s     r,   r   zMiniMaxForCausalLM.forward/  s!    . uww.....r.   )r?   r@   rA   r   rC   rD   s   @r,   r   r   .  s8        / / / / / / / / /r.   r   c                       e Zd ZdS ) MiniMaxForSequenceClassificationNrG   r/   r.   r,   r  r  I  rH   r.   r  c                       e Zd ZdS )MiniMaxForTokenClassificationNrG   r/   r.   r,   r  r  M  rH   r.   r  c                       e Zd ZdS )MiniMaxForQuestionAnsweringNrG   r/   r.   r,   r  r  Q  rH   r.   r  )r"   r   r   r   r  r  r  )ArB   typingr   rn   torch.nn.functionalr   
functionalr   activationsr   cache_utilsr   r   configuration_utilsr	   masking_utilsr
   r   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   processing_utilsr   utilsr   r   utils.deprecationr   utils.genericr   r   mixtral.configuration_mixtralr   mixtral.modeling_mixtralr   r   r   r   r   r   r   r   r   r    
get_loggerr?   loggerr"   rF   rJ   Modulerq   r   r   r   r   r   r   r  r  r  __all__r/   r.   r,   <module>r     sJ                            ! ! ! ! ! ! . . . . . . . . 8 8 8 8 8 8 R R R R R R R R B B B B B B 9 9 9 9 9 9 6 6 6 6 6 6 & & & & & & 0 0 0 0 0 0 0 0 0 0 0 0 0 0 ? ? ? ? ? ? ? ? 9 9 9 9 9 9                        
	H	%	%@H @H @H @H @HM @H @H @HF	 	 	 	 	^ 	 	 	+I +I +I +I +I< +I +I +I\Q/ Q/ Q/ Q/ Q/	 Q/ Q/ Q/h	 	 	 	 	' 	 	 		 	 	 	 	1 	 	 	P P P P P-/I P P Pf    3   I
 I
 I
 I
 I
< I
 I
 I
X/ / / / /+ / / /6	 	 	 	 	'G 	 	 		 	 	 	 	$A 	 	 		 	 	 	 	"= 	 	 	  r.   