
     `i                     z   d dl Z d dlZd dlmZ d dlmZmZmZ d dlZd dlm	Z	 ddl
mZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZ ddlmZmZ ddlmZmZ ddlm Z m!Z!m"Z" ddl#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z)m*Z*m+Z+m,Z,m-Z- ddl.m/Z/  e            rd dl0m1Z1 d dl2m3Z3m4Z4 nd\  Z1Z3Z4 e            r	d dl5m6Z6m7Z7 nd\  Z7Z6 e8e1e6e7f          Z9dZ: ej;        e<          Z= G d dej	        j>                  Z? G d de,          Z@ G d de(          ZA G d  d!e          ZB G d" d#e$          ZC G d$ d%e	j>                  ZD G d& d'e	j>                  ZE G d( d)e%          ZF G d* d+e*          ZG G d, d-e)          ZH G d. d/e          ZI G d0 d1e+eI          ZJ G d2 d3e&          ZK G d4 d5e'          ZLg d6ZMdS )7    N)cycle)CallableOptionalUnion)nn   )ACT2FN)FlashAttentionKwargs)BaseModelOutputWithPast)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)logging)deprecate_kwarg)is_causal_conv1d_availableis_mamba_ssm_available   )LlamaRotaryEmbeddingapply_rotary_pos_emb)pad_tensor_by_sizereshape_into_chunkssegment_sum)
ZambaAttentionZambaAttentionDecoderLayerZambaForCausalLMZambaForSequenceClassificationZambaHybridDynamicCacheZambaHybridLayerZambaMambaDecoderLayer
ZambaModelZambaRMSNormeager_attention_forward   )Zamba2Config)selective_state_update)mamba_chunk_scan_combined mamba_split_conv1d_scan_combinedNNN)causal_conv1d_fncausal_conv1d_updateNNzZyphra/Zamba2-2.7Bc                   (     e Zd Zd fd	ZddZ xZS )Zamba2RMSNormGatedư>c                     t                                                       t          j        t	          j        |                    | _        || _        || _        d S N)	super__init__r   	Parametertorchonesweightvariance_epsilon
group_size)selfhidden_sizer8   eps	__class__s       }/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/zamba2/modular_zamba2.pyr2   zZamba2RMSNormGated.__init__J   sG    l5:k#:#:;; #$    Nc                    |j         }|                    t          j                  }|?|t          j                            |                    t          j                            z  }|j        ^ }}|| j        z  } |j	        g ||| j        R  }|
                    d                              dd          }|t          j        || j        z             z  } |j	        g ||| j        z  R  }| j        |                    |          z  S )Nr   T)keepdim)dtypetor4   float32r   
functionalsilushaper8   viewpowmeanrsqrtr7   r6   )	r9   hidden_statesgateinput_dtypeprefix_dimslast_dimgroup_counthidden_states_groupvariances	            r=   forwardzZamba2RMSNormGated.forwardP   s   #)%((77)BM,>,>twwu}?U?U,V,VVM!.!4h$/10m0\+\{\DO\\\&**1--222t2DD1EK4K`@`4a4aa0+0]+]{T_?\]]]{]--k::::r>   )r.   r0   )__name__
__module____qualname__r2   rT   __classcell__r<   s   @r=   r-   r-   I   sQ        % % % % % %; ; ; ; ; ; ; ;r>   r-   c                       e Zd ZdS )Zamba2RMSNormNrU   rV   rW    r>   r=   r[   r[   ^           Dr>   r[   c            
           e Zd ZdZej        dfdededej        de	e
         fdZded	ej        d
ej        dej        fdZd Zdde	e         defdZdS )Zamba2HybridDynamicCachea  
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
    For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
    while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
    For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
    while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
    and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
    Nconfig
batch_sizerB   devicec           	         || _         |j        | _        d| _        t          |j        |j        z            | _        |j        | _        |j	        | _
        |j        | _        g | _        i | _        i | _        i | _        i | _        i | _        t%          |j                  D ]}t)          j        | j        d|j        z  |j        z  z   | j
        |          | j        |<   t)          j        | j        |j        | j        |          | j        |<   | j        |         dk    r| j                            |           fdt%          |j                  D             | _        fdt%          |j                  D             | _        d S )NFr   rc   rB   hybridc                 D    g | ]}t          j        g gz             S rc   r4   tensor.0_rb   rc   s     r=   
<listcomp>z5Zamba2HybridDynamicCache.__init__.<locals>.<listcomp>   s/    rrrQ%,tj'8HHHrrrr>   c                 D    g | ]}t          j        g gz             S rh   rj   rl   s     r=   ro   z5Zamba2HybridDynamicCache.__init__.<locals>.<listcomp>   s/    tttqEL"
):6JJJtttr>   )rB   layers_block_typehas_previous_stateintmamba_expandr:   intermediate_sizemamba_d_statessm_state_sizemamba_d_convconv_kernel_sizen_mamba_headstransformer_layers_modules_parameters_buffersconv_states
ssm_statesrangenum_hidden_layersr4   zerosmamba_ngroupsmamba_headdimappend	key_cachevalue_cache)r9   ra   rb   rB   rc   is     ` ` r=   r2   z!Zamba2HybridDynamicCache.__init__p   s    
!'!9"'!$V%86;M%M!N!N$2 & 3#1"$v/00 	2 	2A"'+&V-A)AFDX)XX%# # #DQ "'D.0DdFYbhpu" " "DOA %a(H44'..q111rrrrrRWX^XpRqRqrrrtttttTYZ`ZrTsTstttr>   	layer_idxnew_conv_statecache_positionreturnc                 P   | j         |         }|                    d| j        dz
            }|                    dd          }|                    |j                  |d d d d |f<   | j         |                                          | j         |xx         |z  cc<   | j         |         S )Nr   r#   r@   shiftsdims)r   clampry   rollrC   rc   zero_)r9   r   r   r   
conv_states        r=   update_conv_statez*Zamba2HybridDynamicCache.update_conv_state   s     %i0
'--a1F1JKK__BR_88
+9+<+<Z=N+O+O
111aaa'(#))+++###z1###	**r>   c                 j    | j                                          | j                                         d S r0   )r   r   r   )r9   s    r=   resetzZamba2HybridDynamicCache.reset   s1       r>   r   c                     || j         vr| j         d         n|}t          | j                  |k    s#| j        |                                         dk    rdS | j        |         j        d         S )zYReturns the sequence length of the cached states. A layer index can be optionally passed.r   )r{   lenr   numelrG   )r9   r   s     r=   get_seq_lengthz'Zamba2HybridDynamicCache.get_seq_length   sq     3<4CZ2Z2ZD+A..`i	t~)++t~i/H/N/N/P/PTU/U/U1~i(.r22r>   )r   )rU   rV   rW   __doc__r4   float16r$   rs   rB   r   strr2   Tensor
LongTensorr   r   r   r]   r>   r=   r`   r`   b   s          KP-quu u"u03u<AKuaijmanu u u u@
+
+.3l
+LQL\
+	
+ 
+ 
+ 
+     3 3 3c 3 3 3 3 3 3r>   r`   c                       e Zd ZdS )Zamba2RotaryEmbeddingNr\   r]   r>   r=   r   r      r^   r>   r   c                   p    e Zd ZdZ	 	 	 ddedee         dee         dee         f fdZ edd	d
          	 	 	 dde	j
        dedee	j
                 d	ee         deee	j
        e	j
        f                  dee         dee	j
        ee	j
                 eee	j
                          f         fd            Z xZS )Zamba2AttentionaZ  
    Multi-headed attention from 'Attention Is All You Need' paper.

    Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
    The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
    The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
    (see fig. 2 in https://huggingface.co/papers/2405.16712).
    Additionally, replaced
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
    Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
    layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
    expressivity with a small memory overhead (see Fig. 2 of https://huggingface.co/papers/2411.15242).
    Nra   r   num_fwd_mem_blocksblock_idc           	         t                                          ||           || _        |j        | _        || _        |j        rt          j        g           | _	        t          j        g           | _
        t          j        g           | _        t          | j                  D ]}||j        z  |k    rt          j        t          j        | j        | j        j        d          t          j        | j        j        | j        d                    }t          j        t          j        | j        | j        j        d          t          j        | j        j        | j        d                    }t          j        t          j        | j        | j        j        d          t          j        | j        j        | j        d                    }n9t          j                    }t          j                    }t          j                    }| j	                            |           | j
                            |           | j                            |           d t+          | j                  D             | _        d S )NFbiasc                     i | ]\  }}||	S r]   r]   rm   indexvalues      r=   
<dictcomp>z,Zamba2Attention.__init__.<locals>.<dictcomp>   s    [[[<5%%[[[r>   )r1   r2   r   hybrid_layer_idslayer_block_mapr   use_shared_attention_adapterr   
ModuleListlinear_q_adapter_listlinear_k_adapter_listlinear_v_adapter_listr   num_mem_blocks
SequentialLinearattention_hidden_sizera   adapter_rankIdentityr   	enumerate	layer_dic)
r9   ra   r   r   r   r   linear_q_adapterlinear_k_adapterlinear_v_adapterr<   s
            r=   r2   zZamba2Attention.__init__   s*    	+++"4%6 . 	D)+r):):D&)+r):):D&)+r):):D&4233 D Dv,,88')}	$"<dk>V]bccc	$+":D<V]bccc( ($ (*}	$"<dk>V]bccc	$+":D<V]bccc( ($ (*}	$"<dk>V]bccc	$+":D<V]bccc( ($$
 (*{}}$'){}}$'){}}$*112BCCC*112BCCC*112BCCCC[[9TEY;Z;Z[[[r>   past_key_valuepast_key_values4.58new_nameversionrL   attention_maskposition_embeddingskwargsr   c                    |j         d d         }g |d| j        R }|                     |          }	|                     |          }
|                     |          }| j        j        rX| j        |         }|	 | j        |         |          z   }	|
 | j	        |         |          z   }
| | j
        |         |          z   }|	                    |                              dd          }	|
                    |                              dd          }
|                    |                              dd          }| j        j        r|\  }}t          |	|
||          \  }	}
||                    |
||          \  }
}t           }| j        j        dk    rt$          | j        j                 } || |	|
||f| j        sdn| j        | j        d|\  }} |j        g |dR                                  }|                     |          }||fS )Nr@   r#   r   eagerg        )dropoutscaling)rG   head_dimq_projk_projv_projra   r   r   r   r   r   rH   	transposeuse_mem_roper   updater"   _attn_implementationr   trainingattention_dropoutr   reshape
contiguouso_proj)r9   rL   r   r   r   r   r   input_shapehidden_shapequery_states
key_statesvalue_statesadapter_layer_idxcossinattention_interfaceattn_outputattn_weightss                     r=   rT   zZamba2Attention.forward   sS    $)#2#.88b8$-88{{=11[[//
{{=11;3 	g $y 9'*W$*DEV*WXe*f*ffL#&Sd&@AR&STa&b&bbJ'*W$*DEV*WXe*f*ffL#((66@@AFF__\22<<QBB
#((66@@AFF;# 	`*HC';L*VY[^'_'_$L*&'6'='=j,Xa'b'b$J(?;+w66"9$+:Z"[$7$7	%
  $}HCC$2HL	%
 	%
 	%
 	%
!\ *k);;;;;;FFHHkk+..L((r>   r(   )rU   rV   rW   r   r$   r   rs   r2   r   r4   r   r`   tupler   r
   rT   rX   rY   s   @r=   r   r      sg        $ $(,0"&'\ '\'\ C='\ %SM	'\
 3-'\ '\ '\ '\ '\ '\R _%0A6RRR
 26>BKO1) 1)|1) 1) !.	1)
 "":;1) &eEL%,,F&GH1) -.1) 
u|Xel3XeEL>Q5RR	S1) 1) 1) SR1) 1) 1) 1) 1)r>   r   c                        e Zd ZdZddedee         f fdZ	 	 ddej	        dee
         deej	                 fd	Zddee
         deej	                 fd
Z	 	 ddee
         deej	                 fdZ xZS )Zamba2MambaMixeru  
    Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
    A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
    ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
    and is why Mamba is called **selective** state spaces)
    Nra   r   c           	      z   t                                                       || _        |j        | _        |j        | _        |j        | _        t          |j	        | j        z            | _
        || _        |j        | _        d| _        t          j                    | _        |j        | _        |j        | _        |j        | _        | j        j        | _        |j        | _        |j        | _        |j        | _        |j        | _        | j
        d| j        z  | j        z  z   | _        t          j        | j        | j        d|j        | j        |j        dz
            | _        | j
        | j        z   | j        z   }t          j        | j        ||j                   | _!        t          j"        tG          j$        | j                            | _%        tG          j&        d| j        dz             }t          j"        tG          j'        |                    | _(        tS          | j
        | j
        | j        z  d          | _*        t          j"        tG          j$        | j                            | _+        t          j        | j
        | j        |j                   | _,        tZ          st\          /                    d	           d S d S )
NrF   r   Tr#   )in_channelsout_channelsr   kernel_sizegroupspaddingr   gh㈵>)r8   r;   a  The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d)0r1   r2   ra   r:   rv   rw   rx   ry   rs   rt   ru   r   use_conv_bias
activationr   SiLUactuse_mem_eff_pathr   n_groupsr   r   rz   	num_heads
chunk_sizetime_step_limittime_step_mintime_step_maxconv_dimConv1dconv1dr   add_bias_linearin_projr3   r4   r5   dt_biasarangelogA_logr-   normDout_projis_fast_path_availableloggerwarning_once)r9   ra   r   projection_sizeAr<   s        r=   r2   zZamba2MambaMixer.__init__#  sg   !-$2 & 3!$V%84;K%K!L!L"#1 799 & 7,,2 +%5#1#1.T]1BTEX1XXi+='!+
 
 
 04=@4>Qy'
 
 
 |EJt~$>$>?? LDNQ.//\%)A,,//
&"t/E/V\`
 
 
	 ej8899	$"8$:JQWQghhh% 	>    	 	r>   rL   cache_paramsr   c                    |j         \  }}}| j        | j        z  }d| j        z  d| j        z  | j        z  z   | j        z   }||j        r|                     |                    d                    }	|	j         d         |z
  dz  }
|
|
| j        | j        | j        g}t          j
        |	|d          \  }}}}}t          ||j        | j                 | j        j                            d          | j        j        | j                  }t          j
        || j        ||gd          \  }}}t          j        | j                                                   }|d d d df         d d d d d f                             d| j        | j                                      t          j                  }|d d d d d f                             dd| j                  }| j        d d d df                             d| j                  }| j        d d d df                             d| j                  }|                    || j        |j         d         | j        z            }|                    || j        |j         d         | j        z            }|                    || j        | j                  }t9          |j        | j                 ||||||d |d
  
        }|                    || j        | j        z            }|                     ||          }|                     |          d d d df         }n'|Dt          j         |dk              s,|j!        }||d d d d d f         z                      |          }|                     |          }t          j        | j                                                   }| j"        i nd	| j"        i}|t          j         |dk              }nd}| j#        r| j$        r||rtK          || j        j                            d          | j        j        | j        |f| j        | j&        d | j        | j        j        | j        j'        | j        j        | j        j        | j        | j        d
dd|\  }}nt          j
        || j        | j        | j        gd          \  }}}|p|(                    dd          }tR          j*        +                    || j,        |j         d         z
  df          }|j        | j                 -                    |           t\          	| j        dvr]| /                    |                     |(                    dd                    (                    dd          d d d |f                   }nst]          |(                    dd          | j        j                            d          | j        j        | j                  (                    dd          d d d |f         }t          j
        || j        ||gd          \  }}}|Dt          j         |dk              s,|j!        }||d d d d d f         z                      |          }ta          |                    ||d| j                  |||                    ||| j        d          |                    ||| j        d          f| j&        | j        d d d| j        dd|\  }}|'|%|j        | j                 -                    |           |                    ||d          }|                     ||          }|                     |          }|S )Nr   r#   r@   dim.rB   T)zr  dt_softplusdt_limitF)r  r   seq_idxr   rmsnorm_weightrmsnorm_epsoutproj_weightoutproj_biasheaddimngroupsnorm_before_gatereturn_final_statesr   )rF   swish)xr6   r   r   )r   r  r  r  r  r  r  )1rG   r   rw   ru   r   rr   r   squeezer   r4   splitr*   r   r   r   r6   r   r   expr  floatexpandr   rC   rD   r  r  rH   r%   r   r  r  allrB   r   r   r   r'   r   r7   r   r   rE   padry   copy_r)   r   r&   )r9   rL   r  r   rb   seq_lenrn   groups_time_state_sized_to_removein_projected_statesd_mlpsplit_projection_dimrM   hidden_states_B_CdtBCr  r  r  hidden_states_reshapedoutrB   projected_statesdt_limit_kwargsinput_not_masked	ssm_state	time_stephidden_states_B_C_tr   scan_outputs                                  r=   cuda_kernels_forwardz%Zamba2MambaMixer.cuda_kernels_forwardb  s    "/!4
GQ!%1D!D$001t}3DtGZ3ZZ]a]kk #(G#"&,,}/D/DQ/G/G"H"H(.r2[@QFE$)5$2H$-Y]Yg#h 05<OQekm0n0n0n-Aq$)2 4!(8"**1-- ! ! #(+!')?AWX# # #M1a
 4:++--...A!!!T3,111d
+222t}dFYZZ]]didq]rrAAAAqqq$J&&r2t}==Bl111dC<077DMJJGqqq$|$++B>>Az4=!'!*2MNNAz4=!'!*2MNNA%2%7%7
DNTXTa%b%b"2'7&   M *..z4>DM;YZZM IImT::M--..qqq$|<CC )%)Na<O2P2P)%+!.111d
1K!K O OPU V V#||M::4:++--...A$($8$@bbzSWSgFhO)#(9^q-@#A#A  #' $ L1 L1<;OTd;O!A$K&..q11K$L" f# ##'9#3 $	 :#'=#7!%!3 M M%*(,#" "$ &%" "YY, 6;[$+T]DNK6 6 62'  +*;*E*Ea*K*K'!#!2!2+d.CFYF_`bFc.cef-g" "J !,T^<BB:NNN#+tFW/W/W(,$5$?$?1$E$EFFPPQRTUVVWXWXWXZb[bZbWbc) )%% )9+55a;;#{199!<<![-#'?	) ) )
  i1ooaaa'k)3% ',k%+-CE[\' ' '#q!
 "-eiRS@S6T6T-)/E%2^AAAqqq$J5O%O$S$STY$Z$ZM)B!&&z7BNNFF:wrBBFF:wrBB*  $f (, L $* * &* *&Y (\-E +DN;AA)LLL)..z7BGG"iiT::mmK00
r>   c                 \   1 |j         \  }}}|j        }|0|j        r)                     |                    d                    }nT|=t          j        |dk              s%||d d d d d f         z                      |          }                     |          }|j         d         d j        z  z
  d j	        z   j
        z  z
   j        z
  dz  }	|                    |	|	 j         j         j        gd          \  }}}
}}|d|j         j                                                 }|                    |j                  }|j        r|
                    d          }
|j         j                 }t          j        |dd          }|j        dk    r|d d dd d f         n||d d d d df<   |j         j                                     |           t          j        |                    |j                   j        j        d d dd d f         z  d          } j        r| j        j        z  }                     |                              |          d d d df         }n|                    dd          }t<          j                             | j!        |j         d         z
  df          }|j         j                                     |                                                     |                              dd                    d d d |d d f         }|Dt          j        |dk              s,|j        }||d d d d d f         z                      |          }nt          j"        | j         j#         j
        f|j        |	          }                                          |                    dd                    dd |f                             dd                    }t          j        | j         j	         j
        z   j	         j
        z  gd          \  }}}t          j$         j%        &                                           }|)|j        r!|j        dk    r|d d d df         n|d d dd d f         d d d df         }|                    dd          '                    ||j         d          j#                  } j(        d
         '                     j(        j         d          j#                  }t
          j        j        )                    ||                    |j                  z             }t          j*        | j+                  }|d         '                     j         j#         j
                                      t
          j,                  }t          j$        |d
         |z            }|-                    | j	        d          dd d d f         }|'                    | j	         j         j	        z  |j         d                   .                                }|-                    |d|j         d                   }|d
         |dd d d f         z  }|-                    |d j#                  }||d
         z  }|j         j                                     |j         j                 |z  |z              |-                    | j	        d          dd d d f         }|'                    | j	         j         j	        z  |j         d                   .                                }|-                    |d|j         d                   }|j         j                                     |j                  }|/                    | j        z   j#         j
                  }|/                    | j        z   j
        d          }t          j0        ||          }|/                    | j         j#                  } j1        d
         '                     j1        j         d          j#                  }|||z  z                       |j                  }|-                    |d          d d d df         }ngt<          j        )                    | j(        z             }t          j*        | j+                  }|-                    ||d j#                  &                                }|-                    ||d j
                  &                                }|-                    ||d j
                  &                                }|2                     j         j	        z  d j                  }|2                     j         j	        z  d j                  } j3        | j3        z  z
   j3        z  1 j1        d
         ti          |1          z  }||d
         z  }|                    |j                  |z  }1 fd||||fD             \  }}}}|5                    dddd          }t          j6        |d          }t          j$        to          |                    }|d d d d d d d d d d d f         |d d d d d d d d d d d f         z  }|                    d          }|d
         |5                    ddddd          d
         z  } |                     d          }!|!d
         |d d d d d f         z                      d          }"t          j$        |d d d d d d dd f         |z
            }#||#5                    dddd          d
         z  }$|$5                    ddddd          d
         |5                    ddddd          dd d d f         z                      d          5                    ddddd          }%|%|j        r|j         j                 d d d df         }&n t          j8        |%d d d df                   }&t          j9        |&|%gd          }%t          j$        to          t<          j                             |d d d d d d df         d                              }'|%5                    ddddd          }(|'d         |(d d d d d df         z                      d          })|)5                    ddddd          }*|*d d d df         |*d d df         }}%t          j$        |          }+|dd d d f         |%d d d d d df         z  },|+5                    dddd          }-|,                    d          |-d
         z  }.|"|.z   }|-                    |d j         j#                  }||z   }1dk    r|d d d |d d d d f         }|-                    ||d          }|'|%|j         j                                     |            :                    ||
          }/ ;                    |/                    |                    }0|0S )Nr#   r@   r   r  r   r   r   .re   ).N).NNr  )r  output_sizec                 <    g | ]}t          |j                  S r]   )r   r   )rm   tpad_sizer9   s     r=   ro   z2Zamba2MambaMixer.torch_forward.<locals>.<listcomp>t  s)    %z%z%z\]&9!Xt&W&W%z%z%zr>      )r#   r   )<rG   rB   rr   r   r   r4   r%  rC   ru   r   rw   r   r!  r   r   r   clonerc   	unsqueezer   r   ndimr'  sumr   r6   r   r   r   r   r   rE   r&  ry   r   r   r"  r  r#  r$  r  softplusr   r   rD   r   r   rH   bmmr  repeat_interleaver   r   permutecumsumr   
zeros_likecatr  r  )2r9   input_statesr  r   rb   r(  rn   rB   r4  r,  rM   rL   r/  r7  r   r0  r1  r  r  dAdBdBxr   ssm_states_reshaped
C_reshapedyr  
D_residualA_cumsumLG_intermediateGM_intermediateMY_diagdecay_statesB_decay_contractionstatesprevious_statesdecay_chunkstates_permutedresult
new_statesstate_decay_outC_times_statesstate_decay_out_permutedY_offr:  contextualized_statesr@  s2   `                                                @r=   torch_forwardzZamba2MambaMixer.torch_forward  s   !-!3
GQ"#(G##||L,@,@,C,CDD)%)NA<M2N2N) ,~aaaDj/I IMMeTT#||L99!'+a$2H.HH1t}K\_c_rKrrtx  uC  C  HI  I(8(>(>t5t~V\^ )? )
 )
%1dM2
 #$/?EEGGI!]%9::I. [~~a(()5dnE
"Z
2BGGG
ANASWXAXAX}QQQ111W'='=^k
111aaa8$(8>>zJJJ %	*--8H8O*P*PSWS^SefgfgfgijlmlmlmfmSn*ntv w w w% 6!T[%55M $ 7 7 : :5 A A!!!T3, O - 7 7! < <]..!*]-@-DDaH 
 (8>>zJJJ $])C)C)M)MaPQ)R)R S STUTUTUW_X_W_abababTb c!-eiPQ@Q6R6R-)/E%2^AAAqqq$J5O%O$S$STY$Z$ZMT^T]D<OP$+5  I !HHT[[1H1HA1N1N%O%OPSU]V]U]P]%^%h%hijlm%n%nooM#k-$:PRVR_bfbuRuw{  xE  HL  H[  x[  :\  bd  e  e  eq!Ytz''))***#(G# &(W\\AAAtSL!!r!!!Q'{111dC<7PBa##**:rx|T]SSBl9-44T\5G5JDMZZG$--b7::bh3G3G.GHHBR!344B/"))$.$-I\]]``glgt`uuA2i=1,--B
 		*dmR88dAAAFAT]DNdm4SUVU\]_U`aallnnA		*b!'"+66AI3aaa<0B *11*b$-PPM}Y//C #DN399'7"<sB   		*dmR88dAAAFAT]DNdm4SUVU\]_U`aallnnA		*b!'"+66A &0@CCAGLLJ",//*t~2Mt}^b^q"r"r
T^ ;T=PRSTTJ	-z::Az4>4=AAA y!((a$-HHA]Q&&**1733A 		*b))!!!T3,7AA ''T\(9::BR!344B)11*gr4=YY__aaM		*gD4GHHNNPPA		*gr43FGGMMOOA##DNdm$CX\Xf#ggA##DNdm$CX\Xf#ggA'DO*CCtVH	*-?x-X-XXJ *ByM9M]())B.A &{%z%z%z%zboqrtuwxay%z%z%z"M1a 		!Q1%%A|A2...H 	+a..))A qqq!!!QQQaaa23a111dAAAqqq!!!8K6LLN""r"**A y\AIIaAq!,D,DY,OON""r"**A 	l]111aaa:%>>CCAFFF !9XaaaAAArssl%;h%FGGL"#l&:&:1aA&F&Fy&Q"Q)11!Q1a@@K}OdOdefhiklnoqrOsOstwy}  @A  @A  @A  uA  PB  B  G  G  LM  G  N  N  V  V  WX  Z[  ]^  `a  cd  e  eF'L,K'"."9$."I!!!TSV,"W"'"26!!!RaR%="A"AY8a@@@F)K0A0A(111aaaQRQRQRTV;BWY_0`0`$a$abbK$nnQ1a;;O!/2_QQQ4QT_5UUZZ_`ZaaF1aA66J *111crc6 2Jqqq"u4EIF $i11OT111oqqq!!!T30GGN'6'>'>q!Q'J'J$#''++.Fy.QQE A		*b$.$-HHAJA!||aaa'111aaa'(		*gr22A$)A'7==iHHHii4((
 !%knnU.C.C D D$$r>   c                     t           r/d| j        j        j        j        v r|                     |||          S |                     |||          S )Ncuda)r  r   r6   rc   typer;  ri  )r9   rL   r  r   s       r=   rT   zZamba2MambaMixer.forward  sR     " 	Zf0C0J0O&O&O,,]L.YYY!!-~NNNr>   r0   r+   )rU   rV   rW   r   r$   r   rs   r2   r4   r   r`   r;  ri  rT   rX   rY   s   @r=   r   r     sJ        = =| = = = = = = =D <@15	T T|T 78T !.	T T T Tn% %AY8Z %qyz  {G  rH % % % %J <@15		O 	O 78	O !.		O 	O 	O 	O 	O 	O 	O 	Or>   r   c                   >     e Zd Zddedee         f fdZddZ xZS )	Zamba2MLPNra   r   c           	      n   t                                                       || _        |j        | _        |j        | _        || _        || _        t          j        | j        d| j        z  |j	                  | _
        t          j        | j        | j        |j	                  | _        t          |j                 | _        t          j        g           | _        t#          | j                  D ]}||j        z  |k    rft          j        t          j        | j        j        | j        j        d          t          j        | j        j        d| j        z  d                    }nt          j                    }| j                            |           |j        }d t1          |          D             | _        dS )aQ  
        This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
        is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
        r   r   Fc                     i | ]\  }}||	S r]   r]   r   s      r=   r   z&Zamba2MLP.__init__.<locals>.<dictcomp>  s    VVV<5%%VVVr>   N)r1   r2   ra   r:   ru   r   r   r   r   r   gate_up_proj	down_projr	   
hidden_actact_fnr   gate_up_proj_adapter_listr   r   r   r   r   r   r   r   r   )r9   ra   r   r   r   gate_up_proj_adapterr   r<   s          r=   r2   zZamba2MLP.__init__  s   
 	!-!'!9"4 Id&6D<R8RY_Yoppp4#94;KRXRhiiiV./)+r):):&t.// 	H 	HA6((H44')}Idk5t{7OV[\\\Idk6D<R8RY^___( ($$
 (*{}}$*112FGGGG 1VV9_;U;UVVVr>   c                    |                      |          }| j        |         }| | j        |         |          z   }t          j        |dd          }|                     |d                   |d         z  }|                     |          }|S )Nr   r@   r  r   r#   )rq  r   ru  r4   chunkrt  rr  )r9   hidden_stater   gate_up_stateoutputs        r=   rT   zZamba2MLP.forward  s    )),77N9-	%(Q(Fy(QR^(_(__M1"==={{=#344}Q7GG--r>   r+   r0   )	rU   rV   rW   r$   r   rs   r2   rT   rX   rY   s   @r=   rn  rn    st        W W| WPXY\P] W W W W W W<       r>   rn  c                   R    e Zd Zddedee         dee         f fdZ eddd	          	 	 	 	 ddej	        dej	        dedeej	                 dee
         dee         deej                 dee         deej        eeej        ej        f                  f         fd            Z xZS )Zamba2AttentionDecoderLayerNra   r   r   c                     || _         t          |j                  }t                                          ||           t          |d||          | _        t          |||          | _        d S )Nr@   )r   r   r   )r   r   )	r   r   r   r1   r2   r   	self_attnrn  feed_forward)r9   ra   r   r   num_gsr<   s        r=   r2   z$Zamba2AttentionDecoderLayer.__init__  sl     V,--+++(2RXcklll%fRZ[[[r>   r   r   r   r   FrL   original_hidden_statesr   output_attentionsr   r   r   c           
          t          j        ||gd          }|                     |          } | j        d||||||d|\  }}	|                     |          }|                     ||          }|f}
|r|
|	fz  }
|
S )a  
        Args:
            hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
                This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
                concatenated tensor is then used as input of the pre-attention RMSNorm
                (see fig. 2 in https://huggingface.co/papers/2405.16712).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        r@   r  )rL   r   r   r   r  r   r]   )r4   concatenateinput_layernormr  pre_ff_layernormr  )r9   rL   r  r   r   r   r  r   r   self_attn_weightsoutputss              r=   rT   z#Zamba2AttentionDecoderLayer.forward  s    @ )=:P*QWYZZZ,,];;+94> ,
')+/ 3,
 ,
 ,
 ,
(( --m<<))-CC " 	,)++Gr>   r+   )NNFN)rU   rV   rW   r$   r   rs   r2   r   r4   r   r`   boolr   r   r
   r   FloatTensorrT   rX   rY   s   @r=   r}  r}    sQ       \ \| \x} \X`adXe \ \ \ \ \ \ _%0A6RRR 26>B,1:>3 3|3 !&3 	3
 !.3 "":;3 $D>3 &e&673 -.3 
u (51BEDU1U+V"WW	X3 3 3 SR3 3 3 3 3r>   r}  c                   (     e Zd Zdedef fdZ xZS )Zamba2MambaDecoderLayerra   r   c                     t                                          ||           t          ||          | _        t	          |j        |j                  | _        d S )N)ra   r   r;   )r1   r2   r   mambar[   r:   rms_norm_epsr  )r9   ra   r   r<   s      r=   r2   z Zamba2MambaDecoderLayer.__init__1  sR    +++%VyIII
,V-?VEXYYYr>   )rU   rV   rW   r$   rs   r2   rX   rY   s   @r=   r  r  0  sW        Z| Z Z Z Z Z Z Z Z Z Z Zr>   r  c                   |    e Zd Zdedej        def fdZ eddd          	 	 	 	 	 	 	 	 dde	j
        dee	j
                 dee         dee	j
                 dee	j
                 dee         dee         dee         dee	j                 dee	j        eee	j        e	j        f                  f         fd            Z xZS )Zamba2HybridLayershared_transformerlinearr  c                 `    t                                          |||           | `|| _        d S r0   )r1   r2   shared_transfr  )r9   r  r  r  r<   s       r=   r2   zZamba2HybridLayer.__init__8  s6     	+VU;;;"4r>   r   r   r   r   NFrL   r  r   r   causal_maskr  	use_cacher   r   c
           	          |                      |||||||	          }
|
d         }|r|
d         }|                     |          }|                     |||||||	          }
|r|
d         |f|
dd         z   }
|
S )aY  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
            hidden activations to form the input of the shared transformer layer.
            layer_idx (`int`): layer number.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
        )r  r   r   r   r  r   r   r#   )transformer_hidden_statesr   r   r  r  r   r   N)r  r  mamba_decoder)r9   rL   r  r   r   r  r   r  r  r   layer_outputsr  r  s                r=   rT   zZamba2HybridLayer.forward?  s    B //#9&+/ 3 0 
 
 %2!$4! 	1 -a 0$(KK0I$J$J!**&?)+/ 3 + 
 
  	V*1-/@AMRSRTRTDUUMr>   )NNNNNFFN)rU   rV   rW   r}  r   r   r  r2   r   r4   r   r   rs   r`   r  r   r   r  rT   rX   rY   s   @r=   r  r  7  s`       5"=5GIy5Yp5 5 5 5 5 5 _%0A6RRR :>#'15.2>B,1$):>> >|> !) 6> C=	>
 !.> el+> "":;> $D>> D>> &e&67> 
u (51BEDU1U+V"WW	X> > > SR> > > > >r>   r  c                   N     e Zd ZU eed<   dZdZddgZdZdZ	dZ
dZdZ fdZ xZS )Zamba2PreTrainedModelra   modelTr}  r  r   c                 >   t                                          |           t          |t                    rdt	          j        t	          j        | j        j                  t          j
        | j        j                  t          j
        | j        j                  z
  z  t          j
        | j        j                  z                                 | j        j                  }|t	          j
        t	          j        |                      z   }|j        j                            |           t	          j        d|j        dz             }|j        j                            t	          j
        |                     |j        j                            d           d S d S )N)minr#   g      ?)r1   _init_weights
isinstancer   r4   r"  randra   rz   mathr  r   r   r   time_step_floorexpm1r  datar'  r  r   r  r  fill_)r9   moduler/  inv_dtr  r<   s        r=   r  z#Zamba2PreTrainedModel._init_weights  sM   f%%%f.// 	%
4;4558DK566$+B[9\9\\^(4;4556  e3e44	  %)U["%5%5$5666FN%%f---Q 01 455AL##EIaLL111HM$$$$$	% 	%r>   )rU   rV   rW   r$   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_flash_attn_supports_flex_attn_supports_sdpa_is_statefulr  rX   rY   s   @r=   r  r    sz         &*#68QR"3NL% % % % % % % % %r>   r  c                   (   e Zd ZdZdefdZd Z	 	 	 	 	 	 	 	 	 	 ddeej	                 deej
                 deej	                 d	ee         d
eej                 dee         dee         dee         dee         deej	                 deeef         fdZdS )Zamba2Modelzh
    Model consisting of *config.num_hidden_layers* layers.

    Args:
        config: Zamba2Config
    ra   c                 p   t                               |            | _        j        | _        j        | _        t          j        j        j        | j                  | _	        fdt          j                  D             }g }g }j        | _        t          j                  D ]}j        |         dk    r%|                    t          |                     8j        |         dk    rb|                    t          j        | j        j        | j        j        d                     |                    t          |                     t#          |          }t#          |          }t%          |          }|                     |||          }t          j        |          | _        j        | _        t/          j        j                  | _        j        r5j        rt8                              d           t=                    | _        d| _         | !                                 d S )	Nc                 2    g | ]}t          |           S ))r   )r}  )rm   kra   s     r=   ro   z(Zamba2Model.__init__.<locals>.<listcomp>  s'    hhha-fqAAAhhhr>   r  r   rf   Fr   r  ze`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`.)"r  r2   ra   pad_token_idpadding_idx
vocab_sizer   	Embeddingr:   embed_tokensr   r   rq   r   r   r  r   iterr   
get_layersr   layersr   r[   r  final_layernormr   use_long_contextr	  r
  r   
rotary_embgradient_checkpointing	post_init)r9   ra   blocksmamba_layerslinear_layersr   r  s    `     r=   r2   zZamba2Model.__init__  s   &&tV444!. +L):F<NPTP`aahhhh5QWQfKgKghhh!'!9v/00 	R 	RA'*g55##$;Fa$P$P$PQQQQ)!,88$$RYt{/FH_fk%l%l%lmmm##$;Fa$P$P$PQQQL))]++vEEmF++$*$?!,V-?VEXYYY 	<& ##{   4F;;DO&+# 	r>   c           
      &   g }g | _         d| _        t          | j                  D ]\  }}|dk    r| j        dk    r|| _        t	          |          }| j        j        t          | j        j                  z  dk    r/d| d}t          j
        |dz   dz   dz   d	z   d
z             }	| j                             |	           d}
| j        D ]f}|dk    rY|
| j        j        z  |j        k    rAt          j
        dt          |
          z   dz             }| j                             |           |
dz  }
g| j        j        rpd}
| j        D ]f}|dk    rY|
| j        j        z  |j        k    rAt          j
        dt          |
          z   dz             }| j                             |           |
dz  }
g|                    t          |t	          |          t	          |                               |                    t	          |                     |S )Nr   rf   r#   z	^layers\.z\.shared_transformer\.z(?:z3self_attn\.(?:q_proj|k_proj|v_proj|o_proj)\.weight|z1feed_forward\.(?:gate_up_proj|down_proj)\.weight|z,(?:input_layernorm|pre_ff_layernorm)\.weightz)$z>^shared_transformer\.feed_forward\.gate_up_proj_adapter_list\.z\.(?:0|1)\.weight$zg^shared_transformer\.self_attn\.(?:linear_q_adapter_list|linear_k_adapter_list|linear_v_adapter_list)\.)_tied_weights_keysfirst_transformer_layer_idr   rq   nextra   r   r   r   recompiler   r   r   r   r  )r9   r  r  r  r  layer_id
layer_typeblockprefix_patternmain_keys_pattern
adapter_id_layer_typeadapter_patternattn_adapter_patterns                 r=   r  zZamba2Model.get_layers  sp   "$*+'$-d.D$E$E )	2 )	2 HjX%%2a776>D3V;-DK4P0Q0QQTUUU%R(%R%R%RN(*
& !PQ OO J	J
   ) )% +223DEEE!"J'+'= ( (&(22zDKD^7^bgbp7p7p.0j a"%j//!2"7!8/ /O
 !3::?KKK"a

{? ,%&
+/+A 	, 	,K*h66:Hb;bfkft;t;t79z%q&)*oo%6 '<%<8" 8" 4 !% 7 > >?S T T T&!OJJ/tM7J7JDQ]L^L^__````d<001111r>   N	input_idsr   position_idsr   inputs_embedsr  r  output_hidden_statesreturn_dictr   r   c                     ||n| j         j        }||n| j         j        }||n| j         j        }|	|	n| j         j        }	|d u |d uz  rt          d          | j        r%| j        r|rt          	                    d           d}|| 
                    |          }|}t          j        |          }|r@|>||j        d         n|j        d         }t          | j         || j        | j                  }|
I||                    | j                  nd}t          j        |||j        d         z   |j                  }
||
                    d          }|                     |||
          }| j         j        r|                     ||          }nd }|rd	nd }|rd	nd }t1          | j                  D ]q\  }}|r||fz  }| j        r+| j        r$|                     |j        |||||||||
  
        }n ||||||||||
	  	        }|d         }|r|d         ||d         fz  }r|                     |          }|r||fz  }||j        sd|_        t=          ||r|nd ||          }|	r|n|                                S )NzaYou cannot specify both input_ids and inputs_embeds at the same time, and must specify either onezX`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.Fr   )rB   rc   r  r#   ri   r]   )r  r   r   r  r   r  r  r   T)last_hidden_stater   rL   
attentions) ra   r  r  r  use_return_dict
ValueErrorr  r   r	  r
  r  r4   rB  rG   r`   rB   rc   r   r  r  rC  _update_causal_maskr   r  r   r  _gradient_checkpointing_func__call__r  rr   r   to_tuple)r9   r  r   r  r   r  r  r  r  r  r   rL   r  rb   past_seen_tokensr  r   all_hidden_statesall_self_attnsr   layerr  r{  s                          r=   rT   zZamba2Model.forward  s    2C1N--TXT_Tq$8$D  $+Jj 	 "+!6IIDK<Q	%0%<kk$+B]-t";< 	s   & 	4= 	Y 	j   I  --i88M%!&]!;!;  	v0/8/D++-J]^_J`J6t{JVZV`imituuuO! #.  ..9X.YYY 
 #\ "2]5H5K"KTaTh  N )33A66L..~}n]] ;# 	'"&//-"N"N"&"6@BBD0:d )$+ 6 6 "	: "	:Iu# 6!m%55!* t}  $ A AN!*"#%'! ! !&!+A'#1 +$3&7'(;
! 
! 
! *!,M  : #/"}Q'7&99N,,];;   	2-!11&/Q&15O.(+/8BOOd+%	
 
 
 %;vv&//*;*;;r>   )
NNNNNNNNNN)rU   rV   rW   r   r$   r2   r  r   r4   r   r   r`   r  r  r   r   r   rT   r]   r>   r=   r  r    sI        "| " " " "H. . .d 151537>B59$(,0/3&*59v< v<E,-v< !.v< u/0	v<
 "":;v<   12v< D>v< $D>v< 'tnv< d^v< !!12v< 
u--	.v< v< v< v< v< v<r>   r  c                       e Zd ZdS )Zamba2ForCausalLMNr\   r]   r>   r=   r  r  s  r^   r>   r  c                       e Zd ZdS )Zamba2ForSequenceClassificationNr\   r]   r>   r=   r  r  w  r^   r>   r  )r  r  r  r  )Nr  r  	itertoolsr   typingr   r   r   r4   r   activationsr	   modeling_flash_attention_utilsr
   modeling_outputsr   modeling_utilsr   r   processing_utilsr   utilsr   utils.deprecationr   utils.import_utilsr   r   llama.modeling_llamar   r   mamba2.modeling_mamba2r   r   r   zamba.modeling_zambar   r   r   r   r   r   r   r    r!   r"   configuration_zamba2r$   +mamba_ssm.ops.triton.selective_state_updater%   !mamba_ssm.ops.triton.ssd_combinedr&   r'   causal_conv1dr)   r*   r%  r  _CONFIG_FOR_DOC
get_loggerrU   r	  Moduler-   r[   r`   r   r   r   rn  r}  r  r  r  r  r  r  __all__r]   r>   r=   <module>r     sP     				       , , , , , , , , , ,        ! ! ! ! ! ! B B B B B B 7 7 7 7 7 7 F F F F F F F F & & & & & &      1 0 0 0 0 0        N M M M M M M M Y Y Y Y Y Y Y Y Y Y                        / . . . . .  kRRRRRRmmmmmmmmmZjW57W 8DDDDDDDDD-7**46FH\]^^  '		H	%	%; ; ; ; ; ; ; ;*	 	 	 	 	L 	 	 	D3 D3 D3 D3 D36 D3 D3 D3N	 	 	 	 	0 	 	 	k) k) k) k) k)n k) k) k)\iO iO iO iO iOry iO iO iOX' ' ' ' '	 ' ' 'T< < < < <"< < < <~Z Z Z Z Z4 Z Z ZG G G G G( G G GT% % % % %O % % %:R< R< R< R< R<*3 R< R< R<j	 	 	 	 	( 	 	 		 	 	 	 	&D 	 	 	  r>   