
     `i                        d dl mZmZmZ d dlZd dlZd dlmZ d dl	m
Z
mZ ddlmZ ddlmZmZmZ ddlmZ ddlmZ dd	lmZmZ dd
lmZ ddlmZ ddlmZmZm Z m!Z!m"Z" ddl#m$Z$m%Z% ddl&m'Z'm(Z( ddl)m*Z* ddl+m,Z,m-Z-m.Z. ddl/m0Z0 ddl1m2Z2  G d dej3                  Z4 G d dej3                  Z5dej6        de7dej6        fdZ8	 dFdej3        dej6        dej6        d ej6        d!eej6                 d"e9d#e9d$e*e,         fd%Z:d& Z;dGd'Z< G d( d)ej3                  Z= G d* d+ej3                  Z> G d, d-e          Z? G d. d/e          Z@e- G d0 d1e(                      ZA G d2 d3eA          ZBe- G d4 d5eA                      ZC	 	 dHd6eDe7e7f         d7e9d8e7d!eejE                 d9e7dejF        fd:ZGe- G d; d<eA                      ZHd=ej6        d>e7d?e7fd@ZI e-dAB           G dC dDeAe                      ZJg dEZKdS )I    )CallableOptionalUnionN)OutputRecordercheck_model_inputs   )ACT2FN)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)create_causal_mask)_prepare_4d_attention_mask#_prepare_4d_attention_mask_for_sdpa)FlashAttentionKwargs)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPast)BaseModelOutputWithPastAndCrossAttentionsSeq2SeqLMOutputSeq2SeqModelOutput)ROPE_INIT_FUNCTIONSdynamic_rope_update)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuple)deprecate_kwarg   )MoonshineConfigc                   B     e Zd Z fdZdej        dej        fdZ xZS )MoonshineEncoderMLPc                 
   t                                                       || _        t          |         | _        t          j        |j        |j                  | _	        t          j        |j        |j                  | _
        d S Nsuper__init__configr	   activation_fnnnLinearhidden_sizeintermediate_sizefc1fc2selfr*   
hidden_act	__class__s      /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/moonshine/modeling_moonshine.pyr)   zMoonshineEncoderMLP.__init__4   sc    #J/9V/1IJJ9V5v7IJJ    hidden_statesreturnc                     |                      |          }|                     |          }|                     |          }|S r&   )r0   r+   r1   )r3   r8   s     r6   forwardzMoonshineEncoderMLP.forward;   s=    //**=99//r7   __name__
__module____qualname__r)   torchTensorr;   __classcell__r5   s   @r6   r$   r$   3   sc        K K K K KU\ el        r7   r$   c                   B     e Zd Z fdZdej        dej        fdZ xZS )MoonshineDecoderMLPc                    t                                                       || _        t          |         | _        t          j        |j        |j        dz            | _	        t          j        |j        |j                  | _
        d S )N   r'   r2   s      r6   r)   zMoonshineDecoderMLP.__init__C   sh    #J/9V/1IA1MNN9V5v7IJJr7   r8   r9   c                     |                      |          }|                    dd          \  }}|                     |          |z  }|                     |          }|S )NrG   dim)r0   chunkr+   r1   )r3   r8   gates      r6   r;   zMoonshineDecoderMLP.forwardJ   s_    //+11!1<<t**400=@//r7   r<   rC   s   @r6   rE   rE   B   sc        K K K K KU\ el        r7   rE   r8   n_repr9   c                     | j         \  }}}}|dk    r| S | dddddddddf                             |||||          } |                     |||z  ||          S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    r!   N)shapeexpandreshape)r8   rN   batchnum_key_value_headsslenhead_dims         r6   	repeat_kvrW   R   s    
 2?1D.Ehzz!!!!QQQaaa"23::5BUW\^bdlmmM  (;e(CT8TTTr7           modulequerykeyvalueattention_maskscalingdropoutkwargsc                 R   t          || j                  }t          || j                  }	t          j        ||                    dd                    |z  }
|$|d d d d d d d |j        d         f         }|
|z   }
t          j                            |
dt          j	                  
                    |j                  }
t          j                            |
|| j                  }
t          j        |
|	          }|                    dd                                          }||
fS )NrG   r   rI   )rK   dtype)ptrainingr!   )rW   num_key_value_groupsr@   matmul	transposerP   r,   
functionalsoftmaxfloat32torc   r_   re   
contiguous)rY   rZ   r[   r\   r]   r^   r_   r`   
key_statesvalue_statesattn_weightscausal_maskattn_outputs                r6   eager_attention_forwardrs   ^   s    3 ;<<JUF$?@@L<z';';Aq'A'ABBWLL!$QQQ111.D
0@0D.D%DE#k1=((2U](SSVVW\WbccL=((6?([[L,|\::K''1--88::K$$r7   c                     | ddddf         }| ddddf         }t          j        | |fd                              d          S )	z*Rotates half the hidden dims of the input..r   NrG   r!   rI   rJ   rb   )r@   stackflatten)xx1x2s      r6   rotate_halfrz   x   sQ    	
319B	
319B;Ryb)))11"555r7   c                 T   |                     |          }|                     |          }|dd|j        d         dz  f                             dd          }|dd|j        d         dz  f                             dd          }|j        d         }| dd|f         | d|df         }}|dd|f         |d|df         }
}	||z  t          |          |z  z   }|	|z  t          |	          |z  z   }t	          j        ||gd          }t	          j        ||
gd          }||fS )a  Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    .NrI   rG   rJ   )	unsqueezerP   repeat_interleaverz   r@   cat)qkcossinposition_idsunsqueeze_dim
rotary_dimq_rotq_passk_rotk_passq_embedk_embeds                r6   apply_rotary_pos_embr      s\   ( --
&
&C
--
&
&C c'SYr]a'''
(
:
:1"
:
E
EC
c'SYr]a'''
(
:
:1"
:
E
EC 2Jc;J;&'3
+;)<6Ec;J;&'3
+;)<6E s{{511C78Gs{{511C78G i&)r222Gi&)r222GGr7   c                   |    e Zd ZdZdededededef
 fdZ edd	d
          	 	 	 	 	 dde	j
        deee	j
        e	j
        f                  dee	j
                 d	ee         dee	j                 dee	j
                 dee         dee	j
        ee	j
                 eee	j
                          f         fd            Z xZS )MoonshineAttentionz=Multi-headed attention from 'Attention Is All You Need' paperr*   	layer_idx	is_causalnum_attention_headsrT   c                 Z   t                                                       |                    ||d           || _        || _        t          |d|j        |j        z            | _        |j        |j	        z  | _
        | j        dz  | _        |j        | _        || _        t          j        |j        |j        | j        z  |j                  | _        t          j        |j        |j	        | j        z  |j                  | _        t          j        |j        |j	        | j        z  |j                  | _        t          j        |j        | j        z  |j        d          | _        | j        j        0| j        j        }|| j        |z   dz
  |z  z  }|| j        z
  | _        d S d| _        d S )N)r   rT   rV   g      ࿩biasFr!   r   )r(   r)   updater*   r   getattrr.   r   rV   rT   rf   r^   attention_dropoutr   r,   r-   attention_biasq_projk_projv_projo_projpad_head_dim_to_multiple_ofhead_dim_padding)	r3   r*   r   r   r   rT   target_multipletarget_head_dimr5   s	           r6   r)   zMoonshineAttention.__init__   s    	.AZmnnooo"
F4F&Jd4dee$*$>&B\$\!}d*!'!9"i :T] JQWQf
 
 
 i :T] JQWQf
 
 
 i :T] JQWQf
 
 
 i :T] JFL^ejkkk ;2>"kEO-$-/2QTU2UZi1ijO$3dm$CD!!!$%D!!!r7   past_key_valuepast_key_values4.58new_nameversionNr8   position_embeddingsr]   cache_positionkey_value_statesr`   r9   c                 v   |j         d d         \  }}	|                     |                              ||	| j        j        | j                                      dd          }
|d u}|?|j                            | j	                  }|rd|j        | j	        <   |j
        }n|j        }||n|}|r3|r1|r/|j        | j	                 j        }|j        | j	                 j        }n|                     |                              |d| j        j        | j                                      dd          }|                     |                              |d| j        j        | j                                      dd          }|r$|"|                    ||| j	        d|i          \  }}|sB|\  }}t%          |
|||          \  }
}|&|||d}|                    ||| j	        |          \  }}t&          }| j        j        dk    rt*          | j        j                 }| j        o	|d u o|	dk    }| j        dk    rt0          j        j                            |
d| j        f          }
t0          j        j                            |d| j        f          }t0          j        j                            |d| j        f          } || |
|||f| j        sd	n| j        | j        |d
|\  }}| j        dk    r|dd | j         f         }|                    ||	d                                           }| !                    |          }||fS )NrI   r!   rG   Tr   )r   r   r   eagerr   rX   )r_   r^   r   .)"rP   r   viewr*   rT   rV   rh   
is_updatedgetr   cross_attention_cacheself_attention_cachelayerskeysvaluesr   r   r   r   rs   _attn_implementationr   r   r   r@   r,   ri   padre   r   r^   rR   rm   r   )r3   r8   r   r]   r   r   r   r`   bszq_lenquery_statesis_cross_attentionr   current_statesrn   ro   r   r   cache_kwargsattention_interfacer   rr   rp   s                          r6   r;   zMoonshineAttention.forward   s    #("-
U KK&&++C8WY]Yfggqqrsuvww 	 .T9&(377GGJ! G=A*4>:"1"G"1"F .>-I))} 	/ 	j 	(/?DJ*1$.AHLL N++c2t{>NN1a  N++c2t{>NN1a 
 " o&A+:+A+Adn?OQ_>`, ,(
L " 	*HC';L*VY[^'_'_$L**'*3.YY+:+A+Adnl, ,(
L )@;+w66"9$+:Z"[NK~'=K%!)	 1$$ 8.22<!TEZA[\\L,00aAV=WXXJ 8.22<!TEZA[\\L$7$7
%
  $}HCC$2HL
%
 
%
 
%
 
%
!\  1$$%c+Cd.C-C+C&CDK!))#ub99DDFFkk+..L((r7   )NNNNN)r=   r>   r?   __doc__r"   intboolr)   r    r@   rA   r   tupler
   
LongTensorr   r   r;   rB   rC   s   @r6   r   r      sv       GG#&#& #& 	#&
 !#& !#& #& #& #& #& #&J _%0A6RRR LP15+/5937U) U)|U) &eEL%,,F&GHU) !.	U)
 "%U) !!12U) #5<0U) -.U) 
u|Xel3XeEL>Q5RR	SU) U) U) SRU) U) U) U) U)r7   r   c                   |     e Zd ZU ej        ed<   ddef fdZ ej                    e	d                         Z
 xZS )MoonshineRotaryEmbeddinginv_freqNr*   c                    t                                                       t          |d          rSt          |j        t
                    r9|j                            d|j                            d                    | _        nd| _        |j        | _	        |j        | _
        || _        t          | j                 | _        |                     | j        |          \  }| _        |                     d|d           | j        | _        d S )Nrope_scaling	rope_typetypedefaultr   F)
persistent)r(   r)   hasattr
isinstancer   dictr   r   max_position_embeddingsmax_seq_len_cachedoriginal_max_seq_lenr*   r   rope_init_fnattention_scalingregister_bufferr   original_inv_freq)r3   r*   devicer   r5   s       r6   r)   z!MoonshineRotaryEmbedding.__init__-  s    6>** 	'z&:Mt/T/T 	'#044[&BUBYBYZ`BaBabbDNN&DN"("@$*$B!/?+/+<+<T[&+Q+Q($(ZeDDD!%r7   c                 X   | j         d d d d f                                                             |j        d         dd                              |j                  }|d d d d d f                                         }t          |j        j        t                    r|j        j        dk    r|j        j        nd}t          j
        |d          5  |                                |                                z                      dd          }t          j        ||fd	          }|                                | j        z  }|                                | j        z  }	d d d            n# 1 swxY w Y   |                    |j        
          |	                    |j        
          fS )Nr   rI   r!   mpscpuF)device_typeenabledrG   rJ   rc   )r   floatrQ   rP   rl   r   r   r   strr@   autocastrh   r~   r   r   r   rc   )
r3   rw   r   inv_freq_expandedposition_ids_expandedr   freqsembr   r   s
             r6   r;   z MoonshineRotaryEmbedding.forward>  s    !M$4-8>>@@GGHZ[\H]_acdeehhijiqrr ,QQQaaaZ 8 > > @ @'1!(-'E'Ek!(-[`J`J`ahmmfk^UCCC 	5 	5&,,..1F1L1L1N1NNYYZ[]^__E)UEN333C''))d44C''))d44C		5 	5 	5 	5 	5 	5 	5 	5 	5 	5 	5 	5 	5 	5 	5 vvAGv$$cff17f&;&;;;s   BE++E/2E/r&   )r=   r>   r?   r@   rA   __annotations__r"   r)   no_gradr   r;   rB   rC   s   @r6   r   r   *  s         l/ / / / / / / /" U]__< <  _< < < < <r7   r   c                   4    e Zd Zdedef fdZ eddd          	 	 	 	 	 	 dd
ej        de	ej                 de	ej
                 de	e         de	e         de	ej
                 de	eej        ej        f                  dee         dej        fd            Z xZS )MoonshineEncoderLayerr*   r   c                 Z   t                                                       |j        | _        t          ||d|j        |j                  | _        t          ||j                  | _	        t          j        |j        d          | _        t          j        |j        d          | _        d S )NFr*   r   r   r   rT   r   )r(   r)   r.   r   encoder_num_attention_headsencoder_num_key_value_heads	self_attnr$   encoder_hidden_actmlpr,   	LayerNorminput_layernormpost_attention_layernormr3   r*   r   r5   s      r6   r)   zMoonshineEncoderLayer.__init__O  s    !-+ & B & B
 
 
 'vv/HII!|F,>UKKK(*V5Ge(T(T(T%%%r7   r   r   r   r   NFr8   r]   r   	use_cacher   r   r`   r9   c                     |}	|                      |          } | j        d|||||||d|\  }}
|	|z   }|}	|                     |          }|                     |          }|	|z   }|S )Nr8   r]   r   r   r   r   r    )r   r   r   r   )r3   r8   r]   r   r   r   r   r   r`   residual_s              r6   r;   zMoonshineEncoderLayer.forward_  s     !,,];;)4> 	
')%+) 3	
 	
 	
 	
q !=0 !55mDD// =0r7   )NNNFNN)r=   r>   r?   r"   r   r)   r    r@   rA   r   r   r
   r   r   r   r   r;   rB   rC   s   @r6   r   r   N  s5       U U3 U U U U U U  _%0A6RRR 2637+/$)59KO | !. u/0	
 "% D> !!12 &eEL%,,F&GH +, 
   SR    r7   r   c            !           e Zd Zddedee         f fdZ eddd          	 	 	 	 	 	 	 	 	 	 dd
ej	        deej	                 deej	                 deej	                 deej
                 deej
                 dee         dee         deej
                 deeej	        ej	        f                  deeej	        ej	        f                  dee         deej        eeej        ej        f                  f         fd            Z xZS )MoonshineDecoderLayerNr*   r   c                    t                                                       |j        | _        t          ||d|j        |j                  | _        t          ||d|j        |j                  | _        t          ||j	                  | _
        t          j        |j        d          | _        t          j        |j        d          | _        t          j        |j        d          | _        d S )NTr   Fr   )r(   r)   r.   r   decoder_num_attention_headsdecoder_num_key_value_headsr   encoder_attnrE   decoder_hidden_actr   r,   r   r   r   final_layernormr   s      r6   r)   zMoonshineDecoderLayer.__init__  s    !-+ & B & B
 
 
 / & B & B
 
 
 'vv/HII!|F,>UKKK(*V5Ge(T(T(T%!|F,>UKKKr7   r   r   r   r   Fr8   r]   encoder_hidden_statesencoder_attention_maskr   encoder_position_idsr   r   r   encoder_position_embeddingsr`   r9   c                 F   |}|                      |          } | j        d||||||	|
d|\  }}||z   }|9|}|                     |          }|                     |||||          \  }}||z   }|}|                     |          }|                     |          }||z   }|S )Nr   )r8   r   r]   r   r   r   )r   r   r   r   r   r   )r3   r8   r]   r   r  r   r  r   r   r   r   r  r`   r   r   s                  r6   r;   zMoonshineDecoderLayer.forward  s      !,,];;)4> 	
')%+) 3	
 	
 	
 	
q !=0 ,$H 99-HHM#00+!65 /#  1    M1 %}4M ,,];;// =0r7   r&   )
NNNNNNFNNN)r=   r>   r?   r"   r   r   r)   r    r@   rA   r   r
   r   r   r   r   FloatTensorr;   rB   rC   s   @r6   r   r     s       L L L8C= L L L L L L0 _%0A6RRR 268<9=37;?+/$)59KOSW. .|. !..  (5	.
 !) 6. u/0. 'u'78. "%. D>. !!12. &eEL%,,F&GH. &.eEL%,4N.O%P. +,. 
u (51BEDU1U+V"WW	X. . . SR. . . . .r7   r   c                   P    e Zd ZU eed<   dZdZdZddgZdZ	dZ
dZdej        fdZd	S )
MoonshinePreTrainedModelr*   modelinput_valuesTr   r   input_lengthsc                     t          |dz
  dz  dz             }t          |dz
  dz  dz             }t          |dz
  dz  dz             }|S )zH
        Computes the output length of the convolutional layers
           @   r!      r   rG   )r   )r3   r
  output_conv1_lengthoutput_conv2_lengthoutput_conv3_lengths        r6    _get_feat_extract_output_lengthsz9MoonshinePreTrainedModel._get_feat_extract_output_lengths  sc     "=3#6""<q"@AA!#6#:a"?!"CDD!#6#:a"?!"CDD""r7   N)r=   r>   r?   r"   r   base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modules_supports_flash_attn_supports_sdpa_can_compile_fullgraphr@   r   r  r   r7   r6   r  r    sn         $O&*#02IJN!#e>N # # # # # #r7   r  c            
            e Zd ZdZdZeedZdef fdZ	de
j        fdZde
j        fd	Ze	 ddej        deej                 dee         defd            Z xZS )MoonshineEncoderz
    Transformer encoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MoonshineEncoderLayer`]

    Args:
        config: MoonshineConfig
    r	  )
attentionsr8   r*   c                 `   t                                                     | _        j        }t	          j        d|ddd          | _        t	          j        |d|z  dd	          | _        t	          j        d|z  |dd	          | _        t	          j	        d|d
          | _
        t                    | _        t	          j        fdt          j                  D                       | _        t	          j        |d          | _        d| _        |                                  d S )Nr!   r  r  F)kernel_sizestrider   rG   r  r   )r  r  gh㈵>)
num_groupsnum_channelsepsr*   c                 0    g | ]}t          |          S r   )r   .0idxr*   s     r6   
<listcomp>z-MoonshineEncoder.__init__.<locals>.<listcomp>  $    cccC"63//cccr7   r   )r(   r)   r*   r.   r,   Conv1dconv1conv2conv3	GroupNorm	groupnormr   
rotary_emb
ModuleListrangeencoder_num_hidden_layersr   r   
layer_normgradient_checkpointing	post_init)r3   r*   	embed_dimr5   s    ` r6   r)   zMoonshineEncoder.__init__  s      &	Yq)ReTTT
Yy!i-QqQQQ
Yq9}iQqQQQ
PTUUU2&AAAmcccc5Aa;b;bccc
 
 ,yu===&+#r7   r9   c                     | j         S r&   r+  r3   s    r6   get_input_embeddingsz%MoonshineEncoder.get_input_embeddings  s
    zr7   r\   c                     || _         d S r&   r9  r3   r\   s     r6   set_input_embeddingsz%MoonshineEncoder.set_input_embeddings  s    


r7   Nr]   r`   c                    |                     d          }t          j                            |                     |                    }|                     |          }t          j                            |                     |                    }t          j                            |                     |                    }|	                    ddd          }|| 
                    |j        d                   }d}|ddd|f         dd|f         }| j        j        dk    r|d	k                                    r|nd}n;| j        j        d
k    rt          ||j                  }nt#          ||j                  }t%          j        d|j        d         |j                                       d          }|                     ||          }| j        D ]}	 |	|f|||d|}|                     |          }t1          |          S )a.  
        Args:
            input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
                Float values of the raw speech waveform. Raw speech waveform can be
                obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
                `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
                the soundfile library (`pip install soundfile`). To prepare the array into
                `input_values`, the [`AutoFeatureExtractor`] should be used for padding
                and conversion into a tensor of type `torch.FloatTensor`.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding indices in `input_values`. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
        r!   r   rG   NrI     .flash_attention_2rX   sdpar   )r]   r   r   )last_hidden_state)r|   r,   ri   tanhr+  r/  gelur,  r-  permuter  rP   r*   r   anyr   rc   r   r@   aranger   r0  r   r4  r   )
r3   r	  r]   r`   r8   mask_lendownsample_strider   r   encoder_layers
             r6   r;   zMoonshineEncoder.forward  s   , $--a00**4::l+C+CDD}55**4::m+D+DEE**4::m+D+DEE%--aA66 %<<^=QRT=UVVH *+C1D1D3D1D,DEc9H9nUN{/3FFF4Bc4I3N3N3P3P!ZVZ1V;;!D^UbUh!i!i!;NML_!`!`|A}':1'=mFZ[[[eefghh"oom\JJ![ 	 	M)M-)$7	 
  MM 66&+
 
 
 	
r7   r&   )r=   r>   r?   r   r  r   r   _can_record_outputsr"   r)   r,   Moduler;  r>  r   r@   r  r   rA   r   r   r   r;   rB   rC   s   @r6   r  r    s         %O(. 
      $bi    ")      268
 8
'8
 !.8
 +,	8

 
!8
 8
 8
 8
 8
 8
 8
 8
r7   r  c                       e Zd ZdZ eedd          e eedd          dZdef fdZ	e
	 	 	 	 	 	 	 	 	 ddeej                 d
eej                 deej                 dee         deej                 dee         deej                 deej                 deej                 dee         deeef         fd            Z xZS )MoonshineDecoder	input_idsr!   r   )index
layer_namer   )r  r8   cross_attentionsr*   c                    t                                                     j        | _        j        | _        t          j        j        j        | j                  | _        t          j	        fdt          j                  D                       | _        t          j        j        d          | _        t                    | _        d| _        |                                  d S )Nc                 0    g | ]}t          |          S r   )r   r%  s     r6   r(  z-MoonshineDecoder.__init__.<locals>.<listcomp>W  r)  r7   Fr   r#  )r(   r)   pad_token_idpadding_idx
vocab_sizer,   	Embeddingr.   embed_tokensr1  r2  decoder_num_hidden_layersr   r   normr   r0  r5  r6  r3   r*   r5   s    `r6   r)   zMoonshineDecoder.__init__P  s       !. +L):F<NPTP`aamcccc5Aa;b;bccc
 
 L!3%@@@	2&AAA&+# 	r7   Nr]   r   r   inputs_embedsr   r   r   r  r`   r9   c
                    |du |duz  rt          d          ||                     |          }|r8|6t          t          | j                  t          | j                            }|B||                                nd}t          j        |||j        d         z   |j	                  }||
                    d          }t          | j        |||||          }|}|                     ||          }|	|j        d         }d	}|	d
dd|f         d
d|f         }	| j        j        dk    r|	dk                                    r|	nd}	nS| j        j        dk    r"t          |	|j        |j        d                   }	n!t#          |	|j        |j        d                   }	| j        D ]} ||||f|	|||||d|
}|                     |          }t)          ||r|nd          S )a  
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            of the decoder.
        encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding indices in `encoder_hidden_states`. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        Nz:You must specify exactly one of input_ids or inputs_embedsr#  r   r!   rC  )r*   input_embedsr]   r   r   r   rb   r@  .rA  rX   rB  )r  r   r   r   r   r   )rD  r   )
ValueErrorr[  r   r   r*   get_seq_lengthr@   rI  rP   r   r|   r   r0  r   rH  r   rc   r   r   r]  r   )r3   rQ  r]   r   r   r_  r   r   r   r  r`   past_seen_tokensrq   r8   r   rJ  rK  decoder_layers                     r6   r;   zMoonshineDecoder.forward`  s   0 -t";< 	[YZZZ  --i88M 	v01,dk2R2R2RT`hlhsTtTtTtuuO!CRC^==???de"\ "2]5H5K"KTaTh  N )33A66L(;&))+%
 
 
 &"oom\JJ!-,226H *%;CATATCTAT<T%UVY[d\d[dVd%e"{/3FFFDZ^aDaCfCfChCh)r)?)?nr&&1V;;)L*M,?ATUWAX* *&& *D*M,?ATUWAX* *& "[ 	 	M)M% (>) /#-$7   MM 		-008+/8BOOd
 
 
 	
r7   )	NNNNNNNNN)r=   r>   r?   r  r   r   r   rM  r"   r)   r   r   r@   r   rA   r
   r  r   r   r   r   r   r   r;   rB   rC   s   @r6   rP  rP  G  s       !O$n%7q[YYY.*N+=QSabbb          151537+/59$(59=A9=W
 W
E,-W
 !.W
 u/0	W

 "%W
   12W
 D>W
 !!12W
  ((9:W
 !) 6W
 +,W
 
u--	.W
 W
 W
 W
 W
 W
 W
 W
r7   rP  rP   	mask_probmask_length	min_masksc                 @   | \  }dk     rt          d          k    rt          d d d          t          j                            d                                          fd}|9|                                                    d                                          nfd	t          |          D             }t          j	        |ft          
          }g }	 |          }
|
dk    r|S |D ]} ||          }t          j                            t          j        |dz
  z
            |d          }t          |          dk    rdz
  }n|d         }t          j        |t          j        |
|z
  t          j        
          |z  g          }|	                    |           t          j        |	          }	t          j        |	dddddf         ||
f          }	|	                    ||
z            }	t          j                  ddddf         }t          j        |||
f                              ||
z            }|	|z   }	|	                                dz
  k    rdz
  |	|	dz
  k    <   t          j        ||	dd           |S )an  
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
    ASR](https://huggingface.co/papers/1904.08779). Note that this method is not optimized to run on TPU and should be run on
    CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be of a tuple of size 2 where
               the first element is the batch size and the second element is the length of the axis to span.
        mask_prob:  The percentage of the whole axis (between 0 and 1) which will be masked. The number of
                    independently generated mask spans of length `mask_length` is computed by
                    `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
                    actual percentage will be smaller.
        mask_length: size of the mask
        min_masks: minimum number of masked spans
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
                        each batch dimension.
    r!   z&`mask_length` has to be bigger than 0.zO`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: z and `sequence_length`: `c                     t          | z  z  z             }t          |          }|z  k    rz  }| dz
  z
  |k     rt          | dz
  z
  d          }|S )z;Given input length, compute how many spans should be maskedr!   r   )r   max)input_lengthnum_masked_spanepsilonrg  rf  rh  sequence_lengths     r6   compute_num_masked_spanz6_compute_mask_indices.<locals>.compute_num_masked_span  s~    i,6DwNOOoy99 [(?::-<O ;?+o==!,+/"BAFFOr7   NrI   c                     g | ]}S r   r   )r&  r   rp  s     r6   r(  z)_compute_mask_indices.<locals>.<listcomp>  s    999!o999r7   r   r   F)replace)rb  nprandomranditemdetachsumtolistr2  zerosr   choicerI  lenconcatenateonesint32appendarraybroadcast_torR   rl  put_along_axis)rP   rf  rg  r]   rh  
batch_sizerq  r
  spec_aug_maskspec_aug_mask_idxsmax_num_masked_spanrm  rn  spec_aug_mask_idxdummy_mask_idxoffsetsro  rp  s    `` `           @@r6   _compute_mask_indicesr    sP   0 #(JQABBB_$$:^i : :'6: : :
 
 	
 innQ$$&&G        $ % 	##B''..0009999uZ'8'8999  Hj/:$GGGM11/BBa% 5 511,?? I,,IlkAo677RW - 
 
  !!Q&& -q0NN.q1NN(;o(MUWU] ^ ^ ^ao op
 
 	!!"34444"455 111aaa:&5H+(V  ,33J@SVa@abb i$$T4]3Gog
4G'UVV^^'+5 G ,g5 /A"555GVYZGZ-!0CCD m%7B???r7   c                       e Zd Zdef fdZd Zd Zd Zd Z	 dde	j
        d	ee	j                 fd
Zee	 	 	 	 	 	 	 	 	 	 ddee	j
                 d	ee	j                 dee	j                 dee	j                 deeee	j
                                   deeeee	j
                 f                  deee	j
                          deee	j                          dee         dee	j                 dee         defd                        Z xZS )MoonshineModelr*   c                     t                                          |           t          |          | _        t	          |          | _        |                                  d S r&   )r(   r)   r  encoderrP  decoderr6  r^  s     r6   r)   zMoonshineModel.__init__4  sO       '//'//r7   c                     | j         j        S r&   r  r[  r:  s    r6   r;  z#MoonshineModel.get_input_embeddings<  s    |((r7   c                     || j         _        d S r&   r  r=  s     r6   r>  z#MoonshineModel.set_input_embeddings?  s    $)!!!r7   c                     | j         S r&   )r  r:  s    r6   get_encoderzMoonshineModel.get_encoderB  s
    |r7   c                 8    | j                                          dS )z
        Calling this function will disable the gradient computation for the Moonshine encoder so that its parameters will
        not be updated during training.
        N)r  _freeze_parametersr:  s    r6   freeze_encoderzMoonshineModel.freeze_encoderE  s    
 	'')))))r7   Ninput_featuresr]   c                 ~   t          | j        dd          s|S |                                \  }}}| j        j        dk    r| j        rt          ||f| j        j        | j        j        || j        j                  }t          j	        ||j
        t          j                  }|dddf                             d|d          }d||<   | j        j        dk    re| j        r^t          ||f| j        j        | j        j        | j        j                  }t          j	        ||j
        t          j                  }d||<   |S )	z
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://huggingface.co/papers/1904.08779).
        apply_spec_augmentTr   )rf  rg  r]   rh  )r   rc   NrI   )rf  rg  rh  )r   r*   sizemask_time_probre   r  mask_time_lengthmask_time_min_masksr@   tensorr   r   rQ   mask_feature_probmask_feature_lengthmask_feature_min_masks)r3   r  r]   r  r.   rp  mask_time_indicesmask_feature_indicess           r6   _mask_input_featuresz#MoonshineModel._mask_input_featuresL  s[    t{$8$?? 	"!! 4B3F3F3H3H0
K;%))dm) 5_-+4 K8-+9! ! ! !&->~G\didn o o o 1!!!T' : A A"kSU V V01N,-;(1,,,#8[)+7 K;+<	$ $ $  $)<0D^Mbjojt#u#u#u 34N/0r7   r	  decoder_input_idsdecoder_attention_maskencoder_outputsr   decoder_inputs_embedsdecoder_position_idsr   r   r`   r9   c                     | | j         |fd|i|} | j        d||||j        ||||	|
d	|}t          |j        |j        |j        |j        |j        |j        |j        |j                  S )a
  
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
            `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
            the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoFeatureExtractor, MoonshineModel
        >>> from datasets import load_dataset

        >>> model = MoonshineModel.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> feature_extractor = AutoFeatureExtractor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values
        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
        >>> last_hidden_state = model(input_values, decoder_input_ids=decoder_input_ids).last_hidden_state
        >>> list(last_hidden_state.shape)
        [1, 2, 288]
        ```
        Nr]   )	rQ  r]   r  r   r   r_  r   r   r   )rD  r   decoder_hidden_statesdecoder_attentionsrT  encoder_last_hidden_stater   encoder_attentionsr   )r  r  rD  r   r   r8   r  rT  )r3   r	  r]   r  r  r  r   r  r  r   r   r`   decoder_outputss                r6   r;   zMoonshineModel.forwardw  s    \ "/;t|L/r/rYg/rkq/r/rOEQT\ F
'1#1"1"C+/-)F
 F
 F
 F
 "-?+;"1"?.9,=&5&G"1"?.9	
 	
 	
 		
r7   r&   )
NNNNNNNNNN)r=   r>   r?   r"   r)   r;  r>  r  r  r@   r  r   r   r  r   r   r   r   r   r   r   r   r   r;   rB   rC   s   @r6   r  r  2  s             ) ) )* * *  * * * 6:) ))) !!12) ) ) )V  59598<=AEIZ^DHBF$(59E
 E
u01E
 !!12E
 $E$45	E

 !))9 :E
 "%e.?(@"ABE
 "%(;U5CT=U(U"VWE
  (e.?(@AE
 'uU-='>?E
 D>E
 !!12E
 +,E
 
E
 E
 E
 ^ E
 E
 E
 E
 E
r7   r  rQ  rW  decoder_start_token_idc                     |                      | j                  }| ddddf                                         |ddddf<   ||dddf<   |t          d          |                    |dk    |           |S )z1
    Shift input ids one token to the right.
    NrI   r!   r   z1self.model.config.pad_token_id has to be defined.i)	new_zerosrP   clonerb  masked_fill_)rQ  rW  r  shifted_input_idss       r6   shift_tokens_rightr    s     "++IO<<(CRC06688aaae4aaadLMMM""#4#<lKKKr7   zj
    The Moonshine Model with a language modeling head. Can be used for automatic speech recognition.
    )custom_introc                       e Zd ZdgZdef fdZd Zd Zd Zd Z	de
j        fd	Zee	 	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deeeej                                   deeeeej                 f                  deeej                          deeej                          dee         deej                 deej                 dee         defd                        Z xZS )!MoonshineForConditionalGenerationzproj_out.weightr*   c                     t                                          |           t          |          | _        t	          j        |j        |j        d          | _        | 	                                 d S )NFr   )
r(   r)   r  r  r,   r-   r.   rY  proj_outr6  r^  s     r6   r)   z*MoonshineForConditionalGeneration.__init__  s`       #F++
	&"4f6GeTTT 	r7   c                 4    | j                                         S r&   )r  r  r:  s    r6   r  z-MoonshineForConditionalGeneration.get_encoder      z%%'''r7   c                 4    | j                                         S r&   )r  get_decoderr:  s    r6   r  z-MoonshineForConditionalGeneration.get_decoder  r  r7   c                     | j         S r&   r  r:  s    r6   get_output_embeddingsz7MoonshineForConditionalGeneration.get_output_embeddings  s
    }r7   c                     || _         d S r&   r  )r3   new_embeddingss     r6   set_output_embeddingsz7MoonshineForConditionalGeneration.set_output_embeddings  s    &r7   r9   c                 4    | j                                         S r&   )r  r;  r:  s    r6   r;  z6MoonshineForConditionalGeneration.get_input_embeddings  s    z..000r7   Nr	  r]   r  r  r  r   r  r  r   r   labelsr`   c                 ~   |)|'|%t          || j        j        | j        j                  } | j        |f||||||||	|
d	|}|                     |j                  }d}|"|                     ||| j        j                  }t          |||j
        |j        |j        |j        |j        |j        |j        	  	        S )a0  
        input_values (`torch.FloatTensor` of shape `(batch_size, audio_length)`):
            Float values of the raw speech waveform. Raw speech waveform can be
            obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]`, a
            `numpy.ndarray` or a `torch.Tensor`, *e.g.* via the torchcodec library (`pip install torchcodec`) or
            the soundfile library (`pip install soundfile`). To prepare the array into
            `input_values`, the [`AutoFeatureExtractor`] should be used for padding
            and conversion into a tensor of type `torch.FloatTensor`.
        decoder_position_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Used to calculate the position embeddings up to `config.decoder_config.max_position_embeddings`

        Example:

        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, MoonshineForConditionalGeneration
        >>> from datasets import load_dataset

        >>> processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
        >>> model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

        >>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
        >>> input_values = inputs.input_values

        >>> generated_ids = model.generate(input_values, max_new_tokens=100)

        >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        >>> transcription
        'Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
        ```N)	r]   r  r  r  r   r  r  r   r   )logitsr  rY  )	lossr  r   r  r  rT  r  r   r  )r  r*   rW  r  r  r  rD  loss_functionrY  r   r   r  r  rT  r  r   r  )r3   r	  r]   r  r  r  r   r  r  r   r   r  r`   outputsr  r  s                   r6   r;   z)MoonshineForConditionalGeneration.forward  s   f  (-B-J$6DK4dk6X% %! '1dj'
)/+#9+"7!5)'
 '
 '
 '
 w899%%VFt{Oe%ffD#3")"?&9$5&-&G")"?&9

 

 

 
	
r7   )NNNNNNNNNNN)r=   r>   r?   _tied_weights_keysr"   r)   r  r  r  r  r,   rN  r;  r   r   r   r@   r  r   r   r   r   r   r   r   r   r;   rB   rC   s   @r6   r  r    s        ,,      ( ( (( ( (  ' ' '1bi 1 1 1 1  59598<=AEIZ^DHBF$(59-1T
 T
u01T
 !!12T
 $E$45	T

 !))9 :T
 "%e.?(@"ABT
 "%(;U5CT=U(U"VWT
  (e.?(@AT
 'uU-='>?T
 D>T
 !!12T
 )*T
 +,T
 
T
 T
 T
 ^ T
 T
 T
 T
 T
r7   r  )r  r  r  )rX   )Nr!   )Nr   )Ltypingr   r   r   numpyrt  r@   torch.nnr,   transformers.utils.genericr   r   activationsr	   cache_utilsr
   r   r   
generationr   masking_utilsr   modeling_attn_mask_utilsr   r   modeling_flash_attention_utilsr   modeling_layersr   modeling_outputsr   r   r   r   r   modeling_rope_utilsr   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   utils.deprecationr    configuration_moonshiner"   rN  r$   rE   rA   r   rW   r   rs   rz   r   r   r   r   r   r  r  rP  r   r   ndarrayr  r  r  r  __all__r   r7   r6   <module>r     s  * - , , , , , , , , ,            I I I I I I I I ! ! ! ! ! ! C C C C C C C C C C ) ) ) ) ) ) / / / / / / g g g g g g g g B B B B B B 9 9 9 9 9 9              L K K K K K K K F F F F F F F F & & & & & & I I I I I I I I I I 0 0 0 0 0 0 4 4 4 4 4 4    ")       ")    	UU\ 	U# 	U%, 	U 	U 	U 	U& % %I%<% 
% <	%
 U\*% % % '(% % % %46 6 6' ' ' 'T~) ~) ~) ~) ~) ~) ~) ~)B!< !< !< !< !<ry !< !< !<H1 1 1 1 16 1 1 1hH H H H H6 H H HV # # # # # # # #._
 _
 _
 _
 _
/ _
 _
 _
D p
 p
 p
 p
 p
/ p
 p
 p
n 26t tc?tt t U-.	t
 t Zt t t tn K
 K
 K
 K
 K
- K
 K
 K
\%, c [^        
p
 p
 p
 p
 p
(@/ p
 p
 
p
f ^
]
]r7   