
     `i                        d Z ddlZddlZddlmZ ddlmZ ddlm	Z	m
Z
 ddlZddlmZ ddlmZmZmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZmZm Z m!Z!m"Z" ddl#m$Z$  e!j%        e&          Z'dZ(dZ)e G d de                      Z*e G d de                      Z+e G d de                      Z,dKdZ-dLdZ.dMdZ/ G d d ej0                  Z1 G d! d"ej0                  Z2 G d# d$ej0                  Z3 G d% d&ej0                  Z4 G d' d(ej0                  Z5 G d) d*ej0                  Z6 G d+ d,ej0                  Z7 G d- d.ej0                  Z8 G d/ d0ej0                  Z9 G d1 d2e          Z: G d3 d4ej0                  Z; G d5 d6e          Z<d7Z=d8Z> ed9e=           G d: d;e<                      Z? G d< d=ej0                  Z@ ed>e=           G d? d@e<                      ZA G dA dBej0                  ZB G dC dDej0                  ZC G dE dFej0                  ZD edGe=           G dH dIe<                      ZEg dJZFdS )NzPyTorch TVLT model.    N)deepcopy)	dataclass)OptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputSequenceClassifierOutput)PreTrainedModel) find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputadd_start_docstrings%add_start_docstrings_to_model_forwardloggingreplace_return_docstrings   )
TvltConfigr   zZinengTang/tvlt-basec                   x   e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed	<   dZeeej        d
f                  ed<   dZeeej        d
f                  ed<   dS )TvltModelOutputa  
    Class for TvltModel's outputs, with potential hidden states and attentions.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        last_pixel_hidden_state (`torch.FloatTensor` of shape `(batch_size, pixel_sequence_length, hidden_size)`):
            Pixel sequence of hidden-states at the output of the last layer of the model.
        last_audio_hidden_state (`torch.FloatTensor` of shape `(batch_size, audio_sequence_length, hidden_size)`):
            Audio sequence of hidden-states at the output of the last layer of the model.
        pixel_label_masks (`torch.FloatTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor indicating which pixel patches are masked (1) and which are not (0).
        audio_label_masks (`torch.FloatTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor indicating which audio patches are masked (1) and which are not (0).
        pixel_ids_restore (`torch.LongTensor` of shape `(batch_size, pixel_patch_length)`):
            Tensor containing the ids permutation of pixel masking.
        audio_ids_restore (`torch.LongTensor` of shape `(batch_size, audio_patch_length)`):
            Tensor containing the ids permutation of audio masking.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlast_hidden_statelast_pixel_hidden_statelast_audio_hidden_statepixel_label_masksaudio_label_maskspixel_ids_restoreaudio_ids_restore.hidden_states
attentions)__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__r   r   r   
LongTensorr    r!   r"   r#   tupler$        /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/deprecated/tvlt/modeling_tvlt.pyr   r   0   s         8 6:x 12999;?Xe&78???;?Xe&78???48x 0188848x 0188848x 0188848x 01888=AM8E%"3S"89:AAA:>Ju0#567>>>>>r/   r   c                       e Zd ZU dZdZeej                 ed<   dZ	ee
ej        df                  ed<   dZee
ej        df                  ed<   dS )TvltDecoderOutputaM  
    Class for TvltDecoder's outputs, with potential hidden states and attentions.

    Args:
        logits (`torch.FloatTensor` of shape `(batch_size, patch_size ** 2 * num_channels)`):
            Pixel reconstruction logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlogits.r#   r$   )r%   r&   r'   r(   r3   r   r)   r*   r+   r#   r-   r$   r.   r/   r0   r2   r2   Y   sz            +/FHU&'...=AM8E%"3S"89:AAA:>Ju0#567>>>>>r/   r2   c                      e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eej                 ed<   dZeej                 ed<   dZeeej        df                  ed<   dZeeej        df                  ed	<   dS )
TvltForPreTrainingOutputa
  
    Class for TvltForPreTraining's outputs, with potential hidden states and attentions.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`):
            Pixel reconstruction loss.
        matching_logits (`torch.FloatTensor` of shape `(batch_size, 1)`):
            Matching objective logits.
        pixel_logits (`torch.FloatTensor` of shape
            `(batch_size, pixel_patch_length, image_patch_size ** 3 * pixel_num_channels)`): Pixel reconstruction
            logits.
        audio_logits (`torch.FloatTensor` of shape
            `(batch_size, audio_patch_length, image_patch_size[0] * image_patch_size[1])`): Audio reconstruction
            logits.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer
            plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    Nlossmatching_logitspixel_logitsaudio_logits.r#   r$   )r%   r&   r'   r(   r6   r   r)   r*   r+   r7   r8   r9   r#   r-   r$   r.   r/   r0   r5   r5   p   s          0 )-D(5$
%,,,37OXe/077704L(5,-44404L(5,-444=AM8E%"3S"89:AAA:>Ju0#567>>>>>r/   r5         ?c                     | j         dd         \  }}t          j        ||f| j                  }t	          |d|z
  z            }||fS )!Generate noise for audio masking.N   devicer   )shaper)   randr?   int)pixel_values
pixel_mask
mask_ratio
batch_sizeseq_lennoiselen_keeps          r0   generate_pixel_mask_noiserJ      sV     ',RaR0JJ
G,\5HIIIE7a*n-..H(?r/   patch-level   c                 d   | j         dd         \  }}|dk    r^||z  }t          j        ||| j                                      d                              dd|                              ||          }n"|dk    rt          j        ||| j                  }t          |d|z
  z            }	||	fS )r<   Nr=   zframe-levelr>   r   rK   )r@   r)   rA   r?   	unsqueezerepeatviewrB   )
audio_values
audio_maskrE   	mask_typefreq_lenrF   rG   num_time_patchesrH   rI   s
             r0   generate_audio_mask_noiserW      s     ',RaR0JM!!"h.Jz#3L<OPPPYr]]VAq(##T*g&&	 	 
m	#	#
:w|7JKKK7a*n-..H(?r/   c           	         | j         \  }}}t          j        |d          }t          j        |d          }|ddd|f         }	t          j        | d|	                    d                              dd|                    }
t          j        ||g| j                  }d|ddd|f<   t          j        |d|          }|||z  }t          j        |d|	          }|
|||fS )z
    Perform random masking by per-sample shuffling on frame-level. Per-sample shuffling is done by argsort random
    noise. sequence: [batch_size, seq_len, hidden_dim], sequence
    r   dimNrN   rZ   indexr>   r   )r@   r)   argsortgatherrO   rP   onesr?   )sequencerH   rI   attention_masksrF   rG   
hidden_dimids_shuffleids_restoreids_keepsequence_maskedlabel_maskss               r0   random_maskingrh      s    '/n#J -1---K-333K 111ixi<(Hl8(:L:LR:P:P:W:WXY[\^h:i:ijjjO *j'28?KKKK !K9H9,{EEEK"&,AXNNNO[+EEr/   c                   *     e Zd ZdZ fdZddZ xZS )TvltPixelEmbeddings,Construct the patch and position embeddings.c                    t                                                       t          |          | _        | j        j        | _        t          j        t          j        dd|j	                            | _
        t          j        t          j        d|j        |j	                            | _        t          j        t          j        d| j        |j	                            | _        || _        d S Nr   )super__init__TvltPixelPatchEmbeddingspatch_embeddingsnum_patches_per_imager   	Parameterr)   zeroshidden_sizetype_embed_v
num_framestemporal_embedpos_embed_vconfigselfrz   	__class__s     r0   ro   zTvltPixelEmbeddings.__init__   s     8 @ @%)%:%P"LQ6;M)N)NOO l5;q&:KVM_+`+`aa<At7QSYSe(f(fggr/   Nc                     |j         \  }}}}}|                     |          }|| j                            d|d          z  }|t	          j        | j        d d d |f         | j        d          z  }|| j        z  }||fS Nr   rY   )	r@   rq   ry   rP   r)   repeat_interleaverx   rr   rv   )	r|   rC   ra   rF   rw   num_channelsheightwidth
embeddingss	            r0   forwardzTvltPixelEmbeddings.forward   s    >J>P;
Jfe**<88
d&--aQ???
e-d.A!!![j[..QSWSmstuuuu
d''
?**r/   Nr%   r&   r'   r(   ro   r   __classcell__r}   s   @r0   rj   rj      sR        66
 
 
 
 
	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+r/   rj   c                   *     e Zd ZdZ fdZddZ xZS )TvltAudioEmbeddingsrk   c                 X   t                                                       t          |          | _        | j        j        | _        t          j        t          j        dd|j	                            | _
        |j        |j        d         z  | _        t          j        t          j        d| j        | j        z  |j	                            | _        t          j        t          j        d| j        |j	                            | _        |j        |j        d         z  | _        || _        d S rm   )rn   ro   TvltAudioPatchEmbeddingsrq   num_patchesr   rs   r)   rt   ru   type_embed_afrequency_lengthaudio_patch_sizenum_freq_patchespos_embed_a
freq_embedrz   r{   s     r0   ro   zTvltAudioEmbeddings.__init__   s     8 @ @0<LQ6;M)N)NOO & 76;RST;U U<At7G4K`7`bhbt(u(uvv,u{1d6KVM_'`'`aa & 76;RST;U Ur/   Nc                     |                      |          }|                    d          | j        z  }|| j                            d|d          z  }|t          j        | j        d d d |f         | j        d          z  }|| j        z  }||fS r   )	rq   sizer   r   rP   r)   r   r   r   )r|   rR   ra   r   rV   s        r0   r   zTvltAudioEmbeddings.forward   s    **<88
%??1--1FFdo,,Q0@!DDD
e-d.>qqqBSCSBS?S.TVZVkqrssss
d''
?**r/   r   r   r   s   @r0   r   r      sR        66    	+ 	+ 	+ 	+ 	+ 	+ 	+ 	+r/   r   c                   F     e Zd ZdZ fdZdej        dej        fdZ xZS )rp   z
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    t                                                       |j        |j        }}|j        |j        }}t          |t          j        j	                  r|n||f}t          |t          j        j	                  r|n||f}|d         |d         z  |d         |d         z  z  }|| _        || _
        || _        || _        || _        t          j        ||||          | _        d S Nr   r   )kernel_sizestride)rn   ro   
image_sizeimage_patch_sizenum_image_channelsru   
isinstancecollectionsabcIterable
patch_sizer   rr   r   Conv2d
projection)r|   rz   r   r   r   ru   rr   r}   s          r0   ro   z!TvltPixelPatchEmbeddings.__init__	  s    !'!2F4KJ
$*$=v?Qk#-j+/:R#S#SqZZZdfpYq
#-j+/:R#S#SqZZZdfpYq
!+A*Q-!?JqMU_`aUbDb c$$(%:"&)L+:^hiiir/   rC   returnc                    |j         \  }}}}}|| j        k    rt          d          || j        d         k    s|| j        d         k    r2t          d| d| d| j        d          d| j        d          d	          |                    ||z  |||          }|                     |                              d                              dd          }|                    ||| j        z  | j	                  }|S )	NeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*) doesn't match model ().r=   )
r@   r   
ValueErrorr   reshaper   flatten	transposerr   ru   )r|   rC   rF   rw   r   r   r   r   s           r0   r   z TvltPixelPatchEmbeddings.forward  s!   >J>P;
Jfe4,,,w   T_Q'''5DOA4F+F+FwVwwewwDO\]L^wwaeapqraswww   $++J,C\SY[`aa__\22::1==GG1MM
''
JA[4[]a]mnn
r/   	r%   r&   r'   r(   ro   r)   Tensorr   r   r   s   @r0   rp   rp     sm         j j j j j EL U\        r/   rp   c                   F     e Zd ZdZ fdZdej        dej        fdZ xZS )r   z
    This class turns `audio_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    t                                                       |j        |j        |j        }}}|j        |j        }}||f}t          |t          j	        j
                  r|n||f}|d         |d         z  |d         |d         z  z  }|d         |d         z  |d         |d         z  f}	|| _        || _        || _        || _        |	| _        t!          j        ||||          | _        d S r   )rn   ro   spectrogram_lengthr   r   num_audio_channelsru   r   r   r   r   spectrogram_sizer   r   r   patch_shaper   r   r   )r|   rz   r   r   r   r   ru   r   r   r   r}   s             r0   ro   z!TvltAudioPatchEmbeddings.__init__2  s   %## /9,
 %+$=v?Qk.0@A#-j+/:R#S#SqZZZdfpYq
'*jm;@PQR@SWabcWd@de'*jm;=Ma=PT^_`Ta=ab 0$(&&)L+:^hiiir/   rR   r   c                 r   |j         \  }}}}|| j        k    rt          d          || j        d         k    s|| j        d         k    r2t          d| d| d| j        d          d| j        d          d	          |                     |                              d                              dd          }|S )	Nr   r   r   zInput audio size (r   r   r   r=   )r@   r   r   r   r   r   r   )r|   rR   rF   r   r   r   r   s          r0   r   z TvltAudioPatchEmbeddings.forwardG  s    2>2D/
L&%4,,,w   D)!,,,9Nq9Q0Q0QMV M Me M M*1-M M040Ea0HM M M   __\22::1==GG1MM
r/   r   r   s   @r0   r   r   +  sm         j j j j j*EL U\        r/   r   c                   ,     e Zd Z fdZd ZddZ xZS )TvltSelfAttentionc                    t                                                       |j        |j        z  dk    r0t	          |d          s t          d|j         d|j         d          |j        | _        t          |j        |j        z            | _        | j        | j        z  | _        t          j
        |j        | j        |j                  | _        t          j
        |j        | j        |j                  | _        t          j
        |j        | j        |j                  | _        t          j        |j                  | _        d S )Nr   embedding_sizezThe hidden size z4 is not a multiple of the number of attention heads .bias)rn   ro   ru   num_attention_headshasattrr   rB   attention_head_sizeall_head_sizer   Linearqkv_biasquerykeyvalueDropoutattention_probs_dropout_probdropoutr{   s     r0   ro   zTvltSelfAttention.__init__X  s.    ::a??PVXhHiHi?76#5 7 737 7 7  
 $*#= #&v'9F<V'V#W#W !58PPYv143EFO\\\
9V/1C&/ZZZYv143EFO\\\
z&"EFFr/   c                     |                                 d d         | j        | j        fz   } |j        | }|                    dddd          S )NrN   r   r=   r      )r   r   r   rQ   permute)r|   xnew_x_shapes      r0   transpose_for_scoresz&TvltSelfAttention.transpose_for_scoresj  sM    ffhhssmt'?AY&ZZAFK yyAq!$$$r/   NFc                    |                      |          }|                     |                     |                    }|                     |                     |                    }|                     |          }t	          j        ||                    dd                    }	|	t          j        | j	                  z  }	||	|z   }	 t          j        d          |	          }
|                     |
          }
||
|z  }
t	          j        |
|          }|                    dddd                                          }|                                d d         | j        fz   } |j        | }|r||
fn|f}|S )NrN   rY   r   r=   r   r   )r   r   r   r   r)   matmulr   mathsqrtr   r   Softmaxr   r   
contiguousr   r   rQ   )r|   r#   attention_mask	head_maskoutput_attentionsmixed_query_layer	key_layervalue_layerquery_layerattention_scoresattention_probscontext_layernew_context_layer_shapeoutputss                 r0   r   zTvltSelfAttention.forwardo  s    JJ}55--dhh}.E.EFF	//

=0I0IJJ//0ABB !<Y5H5HR5P5PQQ+di8P.Q.QQ%/.@ -"*,,,-=>> ,,77  -	9O_kBB%--aAq99DDFF"/"4"4"6"6ss";t?Q>S"S**,CD6G]=/22mM]r/   NNF)r%   r&   r'   ro   r   r   r   r   s   @r0   r   r   W  s`        G G G G G$% % %
! ! ! ! ! ! ! !r/   r   c                   ^     e Zd ZdZdeddf fdZdej        dej        dej        fdZ xZ	S )	TvltSelfOutputz
    The residual connection is defined in TvltLayer instead of here (as is the case with other models), due to the
    layernorm applied before each block.
    rz   r   Nc                     t                                                       t          j        |j        |j                  | _        t          j        |j                  | _        d S r   )	rn   ro   r   r   ru   denser   hidden_dropout_probr   r{   s     r0   ro   zTvltSelfOutput.__init__  sJ    Yv163EFF
z&"<==r/   r#   input_tensorc                 Z    |                      |          }|                     |          }|S r   r   r   r|   r#   r   s      r0   r   zTvltSelfOutput.forward  s*    

=11]33r/   )
r%   r&   r'   r(   r   ro   r)   r   r   r   r   s   @r0   r   r     s         
>z >d > > > > > >
U\  RWR^        r/   r   c                   ,     e Zd Z fdZd ZddZ xZS )TvltAttentionc                     t                                                       t          |          | _        t	          |          | _        t                      | _        d S r   )rn   ro   r   	attentionr   outputsetpruned_headsr{   s     r0   ro   zTvltAttention.__init__  sI    *622$V,,EEr/   c                    t          |          dk    rd S t          || j        j        | j        j        | j                  \  }}t          | j        j        |          | j        _        t          | j        j        |          | j        _        t          | j        j	        |          | j        _	        t          | j
        j        |d          | j
        _        | j        j        t          |          z
  | j        _        | j        j        | j        j        z  | j        _        | j                            |          | _        d S )Nr   r   rY   )lenr   r   r   r   r   r   r   r   r   r   r   r   union)r|   headsr\   s      r0   prune_headszTvltAttention.prune_heads  s   u::??F74>5t~7Y[_[l
 
u
  2$.2FNN/0BEJJ1$.2FNN.t{/@%QOOO .2^-ORUV[R\R\-\*'+~'IDNLn'n$ -33E::r/   NFc                     |                      ||||          }|                     |d         |          }|f|dd          z   }|S )Nr   r   )r   r   )r|   r#   r   r   r   self_outputsattention_outputr   s           r0   r   zTvltAttention.forward  sM    ~~m^YPabb;;|AFF#%QRR(88r/   r   )r%   r&   r'   ro   r   r   r   r   s   @r0   r   r     s[        " " " " "; ; ;$       r/   r   c                   L     e Zd Zdeddf fdZdej        dej        fdZ xZS )TvltIntermediaterz   r   Nc                    t                                                       t          j        |j        |j                  | _        t          |j        t                    rt          |j                 | _        d S |j        | _        d S r   )rn   ro   r   r   ru   intermediate_sizer   r   
hidden_actstrr   intermediate_act_fnr{   s     r0   ro   zTvltIntermediate.__init__  sn    Yv163KLL
f'-- 	9'-f.?'@D$$$'-'8D$$$r/   r#   c                 Z    |                      |          }|                     |          }|S r   )r   r  r|   r#   s     r0   r   zTvltIntermediate.forward  s,    

=1100??r/   	r%   r&   r'   r   ro   r)   r   r   r   r   s   @r0   r   r     sq        9z 9d 9 9 9 9 9 9U\ el        r/   r   c                   Z     e Zd Zdeddf fdZdej        dej        dej        fdZ xZS )
TvltOutputrz   r   Nc                     t                                                       t          j        |j        |j                  | _        t          j        |j                  | _	        d S r   )
rn   ro   r   r   r  ru   r   r   r   r   r{   s     r0   ro   zTvltOutput.__init__  sJ    Yv79KLL
z&"<==r/   r#   r   c                 d    |                      |          }|                     |          }||z   }|S r   r   r   s      r0   r   zTvltOutput.forward  s4    

=11]33%4r/   r  r   s   @r0   r	  r	    s|        >z >d > > > > > >
U\  RWR^        r/   r	  c                   *     e Zd ZdZ fdZddZ xZS )	TvltLayerz?This corresponds to the Block class in the timm implementation.c                 z   t                                                       |j        | _        d| _        t	          |          | _        t          |          | _        t          |          | _	        t          j        |j        |j                  | _        t          j        |j        |j                  | _        d S Nr   eps)rn   ro   chunk_size_feed_forwardseq_len_dimr   r   r   intermediater	  r   r   	LayerNormru   layer_norm_epslayernorm_beforelayernorm_afterr{   s     r0   ro   zTvltLayer.__init__  s    '-'E$&v..,V44 (( "V-?VEZ [ [ [!|F,>FDYZZZr/   NFc                 H   |                      |                     |          |||          }|d         }|dd          }||                    |j                  z   }|                     |          }|                     |          }|                     ||          }|f|z   }|S )Nr   r   r   )r   r  tor?   r  r  r   )	r|   r#   r   r   r   self_attention_outputsr   r   layer_outputs	            r0   r   zTvltLayer.forward  s    !%!!-00/	 "0 "
 "
 2!4(, )=+;+;<L<S+T+TT ++M::((66 {{<??/G+r/   r   r   r   s   @r0   r  r    sW        II[ [ [ [ [       r/   r  c                   0     e Zd Z fdZ	 	 	 	 	 ddZ xZS )TvltEncoderc                     t                                                       | _        t          j        fdt          j                  D                       | _        d| _        d S )Nc                 .    g | ]}t                    S r.   r  ).0_rz   s     r0   
<listcomp>z(TvltEncoder.__init__.<locals>.<listcomp>  s!    #_#_#_!If$5$5#_#_#_r/   F)	rn   ro   rz   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr{   s    `r0   ro   zTvltEncoder.__init__  s`    ]#_#_#_#_uVE]?^?^#_#_#_``
&+###r/   NFTc                 .   |rdnd }|rdnd }t          | j                  D ]=\  }	}
|r||fz   }|||	         nd } |
||||          }|d         }|r||d         fz   }>|r||fz   }|st          d |||fD                       S t          |||          S )Nr.   r   r   c              3      K   | ]}||V  	d S r   r.   r#  vs     r0   	<genexpr>z&TvltEncoder.forward.<locals>.<genexpr>0  s(      mmq_`_l_l_l_l_lmmr/   )r   r#   r$   )	enumerater)  r-   r   )r|   r#   r   r   r   output_hidden_statesreturn_dictall_hidden_statesall_self_attentionsilayer_modulelayer_head_masklayer_outputss                r0   r   zTvltEncoder.forward  s    #7@BBD$5?bb4(44 	P 	POA|# I$58H$H!.7.CillO(LYjkkM)!,M  P&9]1=M<O&O# 	E 1]4D D 	nmm]4EGZ$[mmmmmm++*
 
 
 	
r/   )NNFFTr%   r&   r'   ro   r   r   r   s   @r0   r  r    s]        , , , , , ""
 "
 "
 "
 "
 "
 "
 "
r/   r  c                   0    e Zd ZU dZeed<   dZdZdZd Z	dS )TvltPreTrainedModelz
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    rz   tvltrC   Tc                    t          |t          j        t          j        f          rT|j        j                            d| j        j                   |j	         |j	        j        
                                 dS dS t          |t          j                  r?|j	        j        
                                 |j        j                            d           dS dS )zInitialize the weights        )meanstdNg      ?)r   r   r   r   weightdatanormal_rz   initializer_ranger   zero_r  fill_)r|   modules     r0   _init_weightsz!TvltPreTrainedModel._init_weightsC  s    fry")455 	* M&&CT[5R&SSS{& &&((((( '&-- 	*K""$$$M$$S)))))	* 	*r/   N)
r%   r&   r'   r(   r   r+   base_model_prefixmain_input_namesupports_gradient_checkpointingrH  r.   r/   r0   r;  r;  8  sN          
 $O&*#
* 
* 
* 
* 
*r/   r;  aF  
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`TvltConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
a	  
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Audio values. Audio values can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_mask (`torch.FloatTensor` of shape `(batch_size, num_pixel_patches)`):
            Pixel masks. Pixel masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        audio_mask (`torch.FloatTensor` of shape `(batch_size, num_audio_patches)`):
            Audio masks. Audio masks can be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for
            details.

        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Pixel values mixed can
            be obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel masks mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        mask_pixel (`bool`, *optional*):
            Whether to mask pixel for MAE tasks. Only set to True in TvltForPreTraining.

        mask_audio (`bool`, *optional*):
            Whether to mask audio for MAE tasks. Only set to True in TvltForPreTraining.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.

        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.

        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
z^The bare TVLT Model transformer outputting raw hidden-states without any specific head on top.c                   <    e Zd Z fdZd Zd Z ee           ee	e
          	 	 	 	 	 	 	 ddej        dej        d	eej                 d
eej                 dededee         dee         dee         deeej                 e	f         fd                        Z xZS )	TvltModelc                    t                                          |           || _        t          |          | _        t          |          | _        t          |          | _        t          j
        t          j        dd|j                            | _        |j        rd | _        n%t          j        |j        |j                  | _        |                                  d S r  )rn   ro   rz   rj   pixel_embeddingsr   audio_embeddingsr  encoderr   rs   r)   rt   ru   cls_embeddinguse_mean_pooling	layernormr  r  	post_initr{   s     r0   ro   zTvltModel.__init__  s        3F ; ; 3F ; ;"6**\%+aF<N*O*OPP" 	Y!DNN\&*<&BWXXXDN 	r/   c                 2    | j         j        | j        j        fS r   )rO  rq   rP  )r|   s    r0   get_input_embeddingszTvltModel.get_input_embeddings  s    $5t7L7]]]r/   c                     |                                 D ]/\  }}| j        j        |         j                            |           0dS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsrQ  r)  r   r   )r|   heads_to_pruner)  r   s       r0   _prune_headszTvltModel._prune_heads  sU    
 +0022 	C 	CLE5Lu%/;;EBBBB	C 	Cr/   output_typeconfig_classNFrC   rR   rD   rS   
mask_pixel
mask_audior   r1  r2  r   c
                    ||n| j         j        }||n| j         j        }|	|	n| j         j        }	|                     ||          \  }
}|                     ||          \  }}d}d}|r7t          |
|| j         j                  \  }}t          |
|||          \  }
}}}d}d}|rb| j         j	        | j         j
        d         z  }t          ||| j         j        | j         j        |          \  }}t          ||||          \  }}}}|                    d          }t          j        | j                            |dd          |
|gd          }|
                    d          }d}|&|$t          j        |ddddf         ||gd          }|                                }d}||                     ||          }|                     |||||	          }|d         }| j        |                     |          }|dddd|z   f         }|ddd|z   df         }|	s|||||||f|dd         z   S t-          ||||||||j        |j        	  	        S )	a  
        Returns:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltModel
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))

        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltModel.from_pretrained("ZinengTang/tvlt-base")

        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```N)rD   rE   )ra   r   )rS   rE   rT   rU   r   )r   r   r1  r2  )	r   r   r   r   r    r!   r"   r#   r$   )rz   r   r1  use_return_dictrO  rP  rJ   pixel_mask_ratiorh   r   r   rW   audio_mask_ratioaudio_mask_typer   r)   catrR  rP   get_extended_attention_maskrQ  rT  r   r#   r$   )r|   rC   rR   rD   rS   r_  r`  r   r1  r2  pixel_embedding_outputaudio_embedding_outputr   r!   pixel_mask_noisepixel_len_keepr    r"   r   audio_mask_noiseaudio_len_keeprF   embedding_outputmasked_pixel_lenr   input_shapeextended_attention_maskencoder_outputssequence_outputpixel_sequence_outputaudio_sequence_outputs                                  r0   r   zTvltModel.forward  sW   J 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]-1-B-B<Q[-\-\*
-1-B-B<Q[-\-\*
 !  		/H&:$+Jf0 0 0,n Xf&  *	X X XT"J0ACT !  	#{;t{?[\]?^^/H&%;7+5)0 0 0,n Xf&  *	X X XT"J0ACT "&&q))
 9&&z1a88:PRhikl
 
 266q99!j&<"Y
111bqb5(9:z'RTUVVN&++--"&%&*&F&F~Wb&c&c#,,2/!5# ' 
 
 *!,>%"nn_==O /1q;K7K3K0K L /17G3G3I3I0I J 		$%%!!!!  #$ $ -$9$9////)7&1

 

 

 
	
r/   )NNFFNNN)r%   r&   r'   ro   rW  r[  r   TVLT_INPUTS_DOCSTRINGr   r   _CONFIG_FOR_DOCr)   r*   r   boolr   r-   r   r   r   s   @r0   rM  rM    sk       
    $^ ^ ^C C C +*+@AA?YYY
 3726  ,0/3&*@
 @
'@
 '@
 U./	@

 U./@
 @
 @
 $D>@
 'tn@
 d^@
 
uU&'8	9@
 @
 @
 ZY BA@
 @
 @
 @
 @
r/   rM  c                   ,     e Zd Z fdZ	 	 	 ddZ xZS )TvltDecoderc                    t                                                       t          |          |j        _        |j        _        |j        _        |j	        _
        t          j        fdt          |j                  D                       | _        t          j        |j        |j                  | _        d| _        || _        d S )Nc                 .    g | ]}t                    S r.   r"  )r#  r$  decoder_configs     r0   r%  z(TvltDecoder.__init__.<locals>.<listcomp>9  s!    XXX1Y~&&XXXr/   r  F)rn   ro   r   decoder_hidden_sizeru   decoder_num_hidden_layersr(  decoder_num_attention_headsr   decoder_intermediate_sizer  r   r&  r'  decoder_layersr  r  rT  r*  rz   )r|   rz   r}  r}   s     @r0   ro   zTvltDecoder.__init__0  s    !&))%+%?"+1+K(-3-O*+1+K( mXXXXf6V0W0WXXX
 
 f&@fF[\\\&+#r/   FTc                 >   |rdnd }|rdnd }t          | j                  D ]0\  }}|r||fz   } |||          }	|	d         }|r||	d         fz   }1|r||fz   }|                     |          }
|st          d |
||fD                       S t	          |
||          S )Nr.   r  r   r   c              3      K   | ]}||V  	d S r   r.   r-  s     r0   r/  z&TvltDecoder.forward.<locals>.<genexpr>]  s(      ffqXYXeXeXeXeXeffr/   )r3   r#   r$   )r0  r  rT  r-   r2   )r|   r#   r   r1  r2  r3  r4  r5  r6  r8  r3   s              r0   r   zTvltDecoder.forwardA  s    #7@BBD$5?bb4()<== 		P 		POA|# I$58H$H!(LJ[\\\M)!,M  P&9]1=M<O&O# 	E 1]4D D .. 	gffV->@S$Tffffff >O\oppppr/   )FFTr9  r   s   @r0   rz  rz  /  s_            (  "q q q q q q q qr/   rz  zTThe TVLT Model transformer with the decoder on top for self-supervised pre-training.c                       e Zd Z fdZd Zd Zd Zd Zd Z e	e
           eee          	 	 	 	 	 	 	 	 dd	ej        d
ej        deej                 deej                 deej                 deej                 deej                 dee         dee         dee         deeej                 ef         fd                        Z xZS )TvltForPreTrainingc                    t                                          |           || _        |j        | _        |j        | _        | j        s| j        st          d          t          |          | _        | j        rt          |          | _	        | j        rt          j        |j        |j        d          | _        t          j        t!          j        dd|j                            | _        t          j        t!          j        dd|j                            | _        t)          |          | _        |j        }|j        }| j        j        j        }t          j        t!          j        d||                    | _        t          j        t!          j        d|j        |                    | _        t          j        t!          j        dd|                    | _        | j        j        j        }|j        |j        d         z  }t          j        t!          j        d||z  |                    | _         t          j        t!          j        d||                    | _!        t          j        t!          j        dd|                    | _"        | j        j#        d         dz  | j        j$        z  }tK          ||          | _&        | j        j        d         | j        j        d         z  | j        j'        z  }tK          ||          | _(        || _        || _        || _)        |j#        | _#        |j        | _        | *                                 d S )Nz;Must set at least one of matching task and MAE task to trueTr   r   r   r=   )+rn   ro   rz   task_matchingtask_maer   rM  r<  TvltMatchingHeadmatching_headr   r   ru   r~  encoder_to_decoderrs   r)   rt   pixel_mask_tokenaudio_mask_tokenrz  decoderrw   rO  rr   decoder_pixel_pos_embeddecoder_temporal_embeddecoder_pixel_type_embedrP  r   r   r   decoder_audio_pos_embeddecoder_freq_embeddecoder_audio_type_embedr   r   TvltMAEHeadpixel_mae_headr   audio_mae_headr   rU  )
r|   rz   r~  rw   rr   num_audio_patchesr   pixel_mae_output_dimaudio_mae_output_dimr}   s
            r0   ro   zTvltForPreTraining.__init__f  s      #1" 	\dm 	\Z[[[f%%	 	:!1&!9!9D= #	<&(i0BFD^ei&j&j&jD#$&LQ6C]1^1^$_$_D!$&LQ6C]1^1^$_$_D!&v..DL"("<*J$(I$>$T!+-<AG\^q8r8r+s+sD(*,,u{1fFWYl7m7m*n*nD',.LQK^9_9_,`,`D) $	 : F%6&:QRS:TT+-<A04DDFYZZ, ,D( ')l5;qBRTg3h3h&i&iD#,.LQK^9_9_,`,`D)#';#?#Ba#G$+Jh#h "-f6J"K"KD,Q/$+2Nq2QQTXT_Trr ! #.f6J"K"KD(DO)>D&$4D!$*$;D!$*$;D! 	r/   c           
         |j         \  }}}}}|j         d         | j        d         z  }|j         d         | j        d         z  }|                    ||||| j        d         || j        d         f          }	t          j        d|	          }	|	                    |||z  |z  | j        d         | j        d         z  |z  f          }	|	S )zJ
        pixel_values: [batch_size, num_frames, 3, height, width]
        r   r   r   r   r@   zntchpwq->nthwpqc)r@   r   r   r)   einsum)
r|   rC   rF   rw   r   r   r   num_patches_heightnum_patches_widthpatchified_pixel_valuess
             r0   patchify_pixelz!TvltForPreTraining.patchify_pixel  s     ?K>P;
Jfe)/2d6KA6NN(.q1T5J15MM"."6"6"%a(!%a( #7 
#
 
#
 #(,/ACZ"["["9"A"A"%66C%a(4+@+CClR #B #
 #
 '&r/   c           	      \   |j         \  }}}}|| j        d         z  }|| j        d         z  }|                    |||| j        d         || j        d         f          }t          j        d|          }|                    |||z  | j        d         | j        d         z  |z  f          }|S )z>
        audio_values: [batch_size, 1, height, width]
        r   r   r  znchpwq->nhwpqc)r@   r   r   r)   r  )	r|   rR   rF   r   r   r   r  r  patchified_audio_valuess	            r0   patchify_audioz!TvltForPreTraining.patchify_audio  s     3?2D/
L&%#t'<Q'??!T%:1%=="."6"6"%a(!%a( #7 	#
 	#
 #(,/?AX"Y"Y"9"A"A"%66%a(4+@+CClR #B #
 #
 '&r/   c                     |                      |          }||z
  dz  }|                    d          }||z                                  |                                z  }|S Nr=   rN   rY   )r  r?  sum)r|   rC   pixel_predictionsmaskr  r6   s         r0   pixel_mae_lossz!TvltForPreTraining.pixel_mae_loss  `    "&"5"5l"C"C!$;;AyyRy  t  ""TXXZZ/r/   c                     |                      |          }||z
  dz  }|                    d          }||z                                  |                                z  }|S r  )r  r?  r  )r|   rR   audio_predictionsr  r  r6   s         r0   audio_mae_lossz!TvltForPreTraining.audio_mae_loss  r  r/   c           	         |j         \  }}}|                    ||j         d         |z
  d          }t          j        ||gd          }t          j        |d|                    d                              dd|                    }|S )Nr   rY   rN   r[   )r@   rP   r)   rf  r^   rO   )	r|   
mask_tokenr`   rd   rF   
seq_lengthrZ   mask_tokenspadded_sequences	            r0   concatenate_maskz#TvltForPreTraining.concatenate_mask  s    &.n#
J ''
K4Ea4H:4UWXYY)X{$;CCC,+*?*?*C*C*J*J1aQT*U*U
 
 
 r/   r\  NrC   rR   rD   rS   labelspixel_values_mixedpixel_mask_mixedr   r1  r2  r   c                    |
|
n| j         j        }
d}| j        r|t          d          |t          d          |                     ||||||	|
          }|d         }|                     |          }t                      } ||                    d          |                    d                    }||z  }d}d}| j        rd| j	        r\|                     ||||dd||	|
		  	        }|
r|j
        n|d
         }|
r|j        n|d         }|
r|j        n|d         }|
r|j        n|d         }|
r|j        n|d         }|
r|j        n|d         }|                     |          }|                     |          }|                    d
          }|                     | j        ||          }|| j                            d
|d
          z   }|t-          j        | j        ddd|f         | j        d
          z   }|| j        z   }|                     |          }|                     |j                  }|                     | j        ||          }|                    d
          | j        z  }|| j                             d
|d
          z   }|t-          j        | j!        ddd|f         | j        d
          z   }|| j"        z   }|                     |          }| #                    |j                  }| $                    |||          | %                    |||          z   }||z  }|
s|||f|dd         z   }||f|z   n|S tM          |||||j'        |j(                  S )aF  
        pixel_values_mixed (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            Pixel values that mix positive and negative samples in Tvlt vision-audio matching. Audio values can be
            obtained using [`TvltProcessor`]. See [`TvltProcessor.__call__`] for details.

        pixel_mask_mixed (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel masks of pixel_values_mixed. Pixel values mixed can be obtained using [`TvltProcessor`]. See
            [`TvltProcessor.__call__`] for details.

        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the vision audio matching loss. Indices should be in `[0, 1]`. num_labels has to be 1.

        Return:

        Examples:

        ```python
        >>> from transformers import TvltProcessor, TvltForPreTraining
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> images_mixed = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForPreTraining.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(
        ...     images, audio, images_mixed, sampling_rate=44100, mask_pixel=True, mask_audio=True, return_tensors="pt"
        ... )

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```Nr>  zMatching task requires labelsz)Matching task requires pixel_values_mixedrD   rS   r   r1  r2  r   rN   T)rD   rS   r_  r`  r   r1  r2  r   r=   r   r         rY      )r6   r7   r8   r9   r#   r$   ))rz   rb  r  r   r<  r  r   rQ   r  trainingr   r   r   r    r!   r"   r  r   r  r  r  rP   r)   r   r  rr   r  r  r  r3   r  r   r  r  r  r  r  r  r5   r#   r$   ) r|   rC   rR   rD   rS   r  r  r  r   r1  r2  
total_lossr   rs  r7   loss_fctr6   r8   r9   rt  ru  r   r    r!   r"   pixel_decoder_inputaudio_decoder_inputrw   pixel_decoder_outputsrV   audio_decoder_outputsr   s                                    r0   r   zTvltForPreTraining.forward  sI   b &1%<kk$+B]
 	~ !@AAA!) !LMMMii"+%"3%9'    G &ajO"00AAO(**H8O0044fkk"ooFFD$J= 0	T] 0	ii%%"3%9'   
 
G HS$bG$C$CX_`aXb!GR$bG$C$CX_`aXb!=H X 9 9gVWj=H X 9 9gVWj=H X 9 9gVWj=H X 9 9gVWj"&"9"9%# # #'"9"9%# # &**1--J"&"7"78MObdu"v"v"58T8[8[\]_ikl8m8m"m"58O+AAA{
{N;T=W]^9 9 9 # #68U"U$(LL1D$E$E!../D/KLLL"&"7"78MObdu"v"v277::d>SS"58O8V8VWXZjlm8n8n"n"58O,QQQ0A1A0A-ABDDY_`9 9 9 # #68U"U$(LL1D$E$E!../D/KLLL&&|\CTUUX\XkXkl,=Y Y D $J 	L%|\BWQRR[PF/3/?ZMF**VK'+%%!/)
 
 
 	
r/   )NNNNNNNN)r%   r&   r'   ro   r  r  r  r  r  r   rv  r   r5   rw  r)   r*   r   r,   rx  r   r-   r   r   r   s   @r0   r  r  a  s       
4 4 4 4 4l' ' '8' ' '6       +*+@AA+CRabbb
 3726-1:>8<,0/3&*H
 H
'H
 'H
 U./	H

 U./H
 )*H
 %U%67H
 #5#45H
 $D>H
 'tnH
 d^H
 
uU&')AA	BH
 H
 H
 cb BAH
 H
 H
 H
 H
r/   r  c                   $     e Zd Z fdZd Z xZS )
TvltPoolerc                     t                                                       t          j        |j        |j                  | _        t          j                    | _        d S r   )rn   ro   r   r   ru   r   Tanh
activationr{   s     r0   ro   zTvltPooler.__init__x  sC    Yv163EFF
'))r/   c                 r    |d d df         }|                      |          }|                     |          }|S )Nr   )r   r  )r|   r#   first_token_tensorpooled_outputs       r0   r   zTvltPooler.forward}  s>    *111a40

#56666r/   r9  r   s   @r0   r  r  w  sG        $ $ $ $ $
      r/   r  c                   $     e Zd Z fdZd Z xZS )r  c                     t                                                       t          |          | _        t	          j        |j        d          | _        d S rm   )rn   ro   r  poolerr   r   ru   fcr{   s     r0   ro   zTvltMatchingHead.__init__  sB     (()F.22r/   c                 V    |                      |                     |                    }|S r   )r  r  r  s     r0   r   zTvltMatchingHead.forward  s%    M : :;;r/   r9  r   s   @r0   r  r    sG        3 3 3 3 3
      r/   r  c                   &     e Zd Zd fd	Zd Z xZS )r  Nc                     t                                                       || _        t          j        |j        |          | _        d S r   )rn   ro   rz   r   r   r~  r  )r|   rz   
output_dimr}   s      r0   ro   zTvltMAEHead.__init__  s:    y!;ZHHr/   c                 0    |                      |          }|S r   )r  r  s     r0   r   zTvltMAEHead.forward  s    ]33r/   r   r9  r   s   @r0   r  r    sR        I I I I I I
      r/   r  z
    Tvlt Model transformer with a classifier head on top (an MLP on top of the final hidden state of the [CLS] token)
    for audiovisual classification tasks, e.g. CMU-MOSEI Sentiment Analysis and Audio to Video Retrieval.
    c                   @    e Zd Z fdZ ee           eee          	 	 	 	 	 	 dde	j
        de	j
        dee	j
                 dee	j
                 dee         d	ee         d
ee         dee	j                 deee	j
                 ef         fd                        Z xZS ) TvltForAudioVisualClassificationc           	         t                                          |           t          |          | _        t	          j        t	          j        |j        |j        dz            t	          j        |j        dz  |j	                  t	          j
                    t	          j        |j        dz  |j                            | _        || _        |                                  d S )Nr=   r  )rn   ro   rM  r<  r   
Sequentialr   ru   r  r  GELU
num_labels
classifierrz   rU  r{   s     r0   ro   z)TvltForAudioVisualClassification.__init__  s       f%%	 -If(&*<q*@AAL+a/V5JKKKGIIIf(1,f.?@@	
 
  	r/   r\  NrC   rR   rD   rS   r   r1  r2  r  r   c	           	         ||n| j         j        }|                     |||||||          }	|	d         dddf         }
|                     |
          }d}|U| j         j        dk    rt                      } |||          }n*| j         j        dk    rt                      } |||          }|s|f|	dd         z   }||f|z   n|S t          |||	j        |	j	                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, num_labels)`, *optional*):
            Labels for computing the audiovisual loss. Indices should be in `[0, ..., num_classes-1]` where num_classes
            refers to the number of classes in audiovisual tasks.

        Return:

        Examples:
        ```python
        >>> from transformers import TvltProcessor, TvltForAudioVisualClassification
        >>> import numpy as np
        >>> import torch

        >>> num_frames = 8
        >>> images = list(np.random.randn(num_frames, 3, 224, 224))
        >>> audio = list(np.random.randn(10000))
        >>> processor = TvltProcessor.from_pretrained("ZinengTang/tvlt-base")
        >>> model = TvltForAudioVisualClassification.from_pretrained("ZinengTang/tvlt-base")
        >>> input_dict = processor(images, audio, sampling_rate=44100, return_tensors="pt")

        >>> outputs = model(**input_dict)
        >>> loss = outputs.loss
        ```Nr  r   
regressionclassificationr   )r6   r3   r#   r$   )
rz   rb  r<  r  	loss_typer
   r	   r   r#   r$   )r|   rC   rR   rD   rS   r   r1  r2  r  r   rs  r3   r6   r  r   s                  r0   r   z(TvltForAudioVisualClassification.forward  s/   H &1%<kk$+B]))!!/!5#  
 
 "!*QQQT*11{$44"99x//&*:::+--x// 	FY,F)-)9TGf$$vE'!/)	
 
 
 	
r/   )NNNNNN)r%   r&   r'   ro   r   rv  r   r   rw  r)   r*   r   rx  r,   r   r-   r   r   r   s   @r0   r  r    sD           " +*+@AA+CRabbb
 3726,0/3&*-1B
 B
'B
 'B
 U./	B

 U./B
 $D>B
 'tnB
 d^B
 )*B
 
uU&')AA	BB
 B
 B
 cb BAB
 B
 B
 B
 B
r/   r  )rM  r  r  r;  )Nr:   )Nr:   rK   rL   r   )Gr(   collections.abcr   r   copyr   dataclassesr   typingr   r   r)   r   torch.nnr   r	   r
   activationsr   modeling_layersr   modeling_outputsr   r   modeling_utilsr   pytorch_utilsr   r   utilsr   r   r   r   r   configuration_tvltr   
get_loggerr%   loggerrw  _CHECKPOINT_FOR_DOCr   r2   r5   rJ   rW   rh   Modulerj   r   rp   r   r   r   r   r   r	  r  r  r;  TVLT_START_DOCSTRINGrv  rM  rz  r  r  r  r  r  __all__r.   r/   r0   <module>r     sO                ! ! ! ! ! ! " " " " " " " "        A A A A A A A A A A " " " " " " : : : : : : J J J J J J J J . . . . . . R R R R R R R R              + * * * * * 
	H	%	%,  %? %? %? %? %?k %? %? %?P ? ? ? ? ? ? ? ?, ? ? ? ? ?{ ? ? ?B      $F F F F:+ + + + +") + + +6+ + + + +") + + +:& & & & &ry & & &R) ) ) ) )ry ) ) )X9 9 9 9 9	 9 9 9x    RY   $    BI   D    ry           # # # # #* # # #L)
 )
 )
 )
 )
") )
 )
 )
X* * * * */ * * *0	 * Z d `
 `
 `
 `
 `
# `
 `
	 `
F/q /q /q /q /q") /q /q /qd Z O
 O
 O
 O
 O
, O
 O
	 O
d
 
 
 
 
 
 
 
    ry       ")      V
 V
 V
 V
 V
': V
 V
 V
r i
h
hr/   