
     `i                     @   d dl Z d dlmZ d dlmZmZmZ d dlZd dlmZ ddl	m
Z
 ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZmZ ddlmZmZ ddlmZ ddlmZmZmZmZ ddl m!Z! ddl"m#Z#m$Z$ e ed           G d de                                  Z%e ed           G d de                                  Z& ed           G d dej'                              Z( G d dej'                  Z) G d d ej'                  Z*	 dEd"ej'        d#ej+        d$ej+        d%ej+        d&eej+                 d'e,d(e,fd)Z- G d* d+ej'                  Z. G d, d-ej'                  Z/ G d. d/ej'                  Z0 G d0 d1e          Z1 G d2 d3ej'                  Z2 G d4 d5ej'                  Z3 G d6 d7ej4                  Z5 G d8 d9e          Z6d:ej+        d;e7fd<Z8 G d= d>e6          Z9 ed?           G d@ dAe6                      Z:e G dB dCe6e                      Z;g dDZ<dS )F    N)	dataclass)CallableOptionalUnion)nn   )ACT2FN)Cache)GenerationMixin)use_kernel_forward_from_hub)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPast)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)ModelOutputTransformersKwargsauto_docstringcan_return_tuple   )	AutoModel   )Ovis2ConfigOvis2VisionConfigzJ
    Base class for Llava outputs, with hidden states and attentions.
    )custom_introc                   8    e Zd ZU dZdZeej                 ed<   dS )Ovis2ModelOutputWithPasta  
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size `(batch_size, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nimage_hidden_states)	__name__
__module____qualname____doc__r   r   torchFloatTensor__annotations__     |/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/ovis2/modeling_ovis2.pyr   r   *   s7         	 	 8<%"34;;;;;r(   r   zQ
    Base class for Ovis2 causal language model (or autoregressive) outputs.
    c                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
ee         ed<   dZeeej                          ed<   dZeeej                          ed<   dZeej                 ed<   dS )	Ovis2CausalLMOutputWithPastaA  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
        It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).

        Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
        `past_key_values` input) to speed up sequential decoding.
    image_hidden_states (`torch.FloatTensor`, *optional*):
        A `torch.FloatTensor` of size (batch_size * num_patches, num_images, sequence_length, hidden_size)`.
        image_hidden_states of the model produced by the vision encoder and after projecting the last hidden state.
    Nlosslogitspast_key_valueshidden_states
attentionsr   )r    r!   r"   r#   r,   r   r$   r%   r&   r-   r.   r
   r/   tupler0   r   r'   r(   r)   r+   r+   ?   s           )-D(5$
%,,,*.FHU&'...'+OXe_+++8<M8E%"345<<<59Ju0129997;%"34;;;;;r(   r+   RMSNormc                   ,     e Zd Zd fd	Zd Zd Z xZS )Ovis2RMSNormư>c                     t                                                       t          j        t	          j        |                    | _        || _        dS )z;
        Ovis2RMSNorm is equivalent to T5LayerNorm
        N)super__init__r   	Parameterr$   onesweightvariance_epsilon)selfhidden_sizeeps	__class__s      r)   r8   zOvis2RMSNorm.__init___   sD     	l5:k#:#:;; #r(   c                    |j         }|                    t          j                  }|                    d                              dd          }|t          j        || j        z             z  }| j        |                    |          z  S )Nr   Tkeepdim)	dtypetor$   float32powmeanrsqrtr<   r;   )r=   r/   input_dtypevariances       r)   forwardzOvis2RMSNorm.forwardg   s|    #)%((77 $$Q'',,R,>>%Ht?T4T(U(UU{]--k::::r(   c                 H    t          | j        j                   d| j         S )Nz, eps=)r1   r;   shaper<   r=   s    r)   
extra_reprzOvis2RMSNorm.extra_reprn   s&    )**II$2GIIIr(   )r5   )r    r!   r"   r8   rM   rQ   __classcell__r@   s   @r)   r4   r4   ]   sb        $ $ $ $ $ $; ; ;J J J J J J Jr(   r4   c                   $     e Zd Z fdZd Z xZS )Ovis2VisionMLPc                    t                                                       || _        |j        | _        |j        | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _	        t          j        | j        | j        |j                  | _
        t          |j                 | _        d S Nbiasr7   r8   configr>   intermediate_sizer   Linearmlp_bias	gate_projup_proj	down_projr	   
hidden_actact_fnr=   r[   r@   s     r)   r8   zOvis2VisionMLP.__init__s       !-!'!94#3T5KRXRabbby!143IPVP_```4#94;KRXRabbbV./r(   c                     |                      |                     |                     |                    |                     |          z            }|S Nra   rc   r_   r`   r=   xra   s      r)   rM   zOvis2VisionMLP.forward}   A    NN4;;t~~a/@/@#A#ADLLQROO#STT	r(   r    r!   r"   r8   rM   rR   rS   s   @r)   rU   rU   r   G        0 0 0 0 0      r(   rU   c                   H     e Zd Zdef fdZdej        dej        fdZ xZ	S )Ovis2VisionEmbeddingsr[   c                 R   t                                                       || _        |j        | _        |j        | _        |j        | _        t          j        |j	        | j        | j        | j        d          | _
        | j        | j        z  dz  | _        | j        | _        t          j        | j        | j                  | _        |                     dt!          j        | j                                      d          d           t'          |j        |j                  | _        d S )Nvalid)in_channelsout_channelskernel_sizestridepaddingr   position_ids)r   rB   F)
persistent)r7   r8   r[   r>   	embed_dim
image_size
patch_sizer   Conv2dnum_channelspatch_embeddingnum_patchesnum_positions	Embeddingposition_embeddingregister_bufferr$   arangeexpandr4   rms_norm_epsrms_normrd   s     r)   r8   zOvis2VisionEmbeddings.__init__   s    + + +!y+? 
  
  
 !Ot>1D!-"$,t/A4>"R"R^U\$:L-M-M-T-TU\-]-]joppp$V%79LMMr(   pixel_valuesreturnc                 0   | j         j        j        }|                      |                    |                    }|                    d                              dd          }|                     |          }||                     | j                  z   }|S )NrE   r   r   )	r~   r;   rE   rF   flatten	transposer   r   rw   )r=   r   target_dtypepatch_embeds
embeddingss        r)   rM   zOvis2VisionEmbeddings.forward   s    +28++LOO,O,O,OPP!))!,,66q!<<
]]:..
$"9"9$:K"L"LL
r(   )
r    r!   r"   r   r8   r$   r%   TensorrM   rR   rS   s   @r)   ro   ro      ss        N0 N N N N N N*E$5 %,        r(   ro           modulequerykeyvalueattention_maskscalingdropoutc                    t          j        ||                    dd                    |z  }|||z   }t          j                            |dt           j                                      |j                  }t          j        	                    ||| j
                  }t          j        ||          }	|	                    dd                                          }	|	|fS )NrB   )dimrE   )ptrainingr   r   )r$   matmulr   r   
functionalsoftmaxrG   rF   rE   r   r   
contiguous)
r   r   r   r   r   r   r   kwargsattn_weightsattn_outputs
             r)   eager_attention_forwardr      s     <s}}R'<'<==GL!#n4=((2U](SSVVW\WbccL=((6?([[L,|U33K''1--88::K$$r(   c            
            e Zd ZdZ fdZ	 ddej        deej                 deej        eej                 f         fdZ	 xZ
S )	Ovis2VisionAttention=Multi-headed attention from 'Attention Is All You Need' paperc                    t                                                       || _        |j        | _        |j        | _        | j        | j        z  | _        | j        | j        z  | j        k    r t          d| j         d| j         d          | j        dz  | _	        |j
        | _        d| _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        d S Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).g      FrX   r7   r8   r[   r>   ry   num_attention_heads	num_headshead_dim
ValueErrorscaleattention_dropoutr   	is_causalr   r]   qkv_biask_projv_projq_projout_projrd   s     r)   r8   zOvis2VisionAttention.__init__   0   +3$.8=4>)T^;;'dn ' 'N' ' '   ]D(
/iV_UUUiV_UUUiV_UUU	$.$.vWWWr(   Nr/   r   r   c           
         |j         \  }}}|                     |          }|                     |          }|                     |          }	|                    ||| j        | j                                      dd          }|                    ||| j        | j                                      dd          }|	                    ||| j        | j                                      dd          }	t          }
| j	        j
        dk    rt          | j	        j
                 }
 |
| |||	|| j        | j        | j        sdn| j                  \  }}|                    |||                                          }|                     |          }||fS z#Input shape: Batch x Time x Channelr   r   eagerr   )r   r   r   rO   r   r   r   viewr   r   r   r   r[   _attn_implementationr   r   r   r   r   reshaper   r   r=   r/   r   r   
batch_size
seq_lengthry   querieskeysvaluesattention_interfacer   r   s                r)   rM   zOvis2VisionAttention.forward   y    -:,?)
J	++m,,{{=))]++,,z:t~t}UU__`acdeeyyZOOYYZ[]^__ZT^T]SS]]^_abcc(?;+w66"9$+:Z"[$7$7nJ#}>CC$,	%
 	%
 	%
!\ "))*j)LLWWYYmmK00L((r(   rg   r    r!   r"   r#   r8   r$   r   r   r1   rM   rR   rS   s   @r)   r   r              GGX X X X X, 26$) $)|$) !.$)
 
u|Xel33	4$) $) $) $) $) $) $) $)r(   r   c                   $     e Zd Z fdZd Z xZS )Ovis2MLPc                    t                                                       || _        |j        | _        |j        | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _	        t          j        | j        | j        |j                  | _
        t          |j                 | _        d S rW   rZ   rd   s     r)   r8   zOvis2MLP.__init__   re   r(   c                     |                      |                     |                     |                    |                     |          z            }|S rg   rh   ri   s      r)   rM   zOvis2MLP.forward  rk   r(   rl   rS   s   @r)   r   r      rm   r(   r   c            
            e Zd ZdZ fdZ	 ddej        deej                 deej        eej                 f         fdZ	 xZ
S )	Ovis2Attentionr   c                    t                                                       || _        |j        | _        |j        | _        | j        | j        z  | _        | j        | j        z  | j        k    r t          d| j         d| j         d          | j        dz  | _	        |j
        | _        d| _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        t          j        | j        | j        |j                  | _        d S r   r   rd   s     r)   r8   zOvis2Attention.__init__
  r   r(   Nr/   r   r   c           
         |j         \  }}}|                     |          }|                     |          }|                     |          }	|                    ||| j        | j                                      dd          }|                    ||| j        | j                                      dd          }|	                    ||| j        | j                                      dd          }	t          }
| j	        j
        dk    rt          | j	        j
                 }
 |
| |||	|| j        | j        | j        sdn| j                  \  }}|                    |||                                          }|                     |          }||fS r   r   r   s                r)   rM   zOvis2Attention.forward  r   r(   rg   r   rS   s   @r)   r   r     r   r(   r   c            	       v     e Zd Zdef fdZ	 d	dej        deej                 dee	         dej        fdZ
 xZS )
Ovis2VisionEncoderLayerr[   c                    t                                                       t          |          | _        t	          |          | _        t          |j        |j                  | _	        t          |j        |j                  | _
        d S rg   )r7   r8   r   	attentionr   ffnr4   r>   r   	rms_norm1	rms_norm2rd   s     r)   r8   z Ovis2VisionEncoderLayer.__init__E  si    '//F##%f&8&:MNN%f&8&:MNNr(   Nr/   r   r   r   c                     |                      |          } | j        d||d|\  }}||z   }|                     |          }|                     |          }||z   }|S )N)r/   r   r'   )r   r   r   r   )r=   r/   r   r   norm_hidden_statesr   _
mlp_outputs           r)   rM   zOvis2VisionEncoderLayer.forwardL  sy     "^^M::'r6HYgrrkqrrQ%3!^^M::XX011
%
2r(   rg   )r    r!   r"   r   r8   r$   r   r   r   r   rM   rR   rS   s   @r)   r   r   D  s        O0 O O O O O O 26 | !. +,	
 
       r(   r   c            	            e Zd ZdZdef fdZee	 d	dee	j
                 dee         defd                        Z xZS )
Ovis2VisionEncoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Ovis2VisionEncoderLayer`].

    Args:
        config: Ovis2VisionConfig
    r[   c                     t                                                       | _        t          j        fdt          j                  D                       | _        d| _        d S )Nc                 .    g | ]}t                    S r'   )r   ).0r   r[   s     r)   
<listcomp>z/Ovis2VisionEncoder.__init__.<locals>.<listcomp>i  s"    $n$n$n%<V%D%D$n$n$nr(   F)	r7   r8   r[   r   
ModuleListrangenum_hidden_layerslayersgradient_checkpointingrd   s    `r)   r8   zOvis2VisionEncoder.__init__f  sa    m$n$n$n$neTZTlNmNm$n$n$noo&+###r(   Nr   r   r   c                 N    |}| j         D ]} |||fi |}t          |          S )Nlast_hidden_state)r   r   )r=   inputs_embedsr   r   r/   encoder_layers         r)   rM   zOvis2VisionEncoder.forwardm  sH     &![ 	S 	SM)M-RR6RRMM????r(   rg   )r    r!   r"   r#   r   r8   r   r   r   r$   r   r   r   r   rM   rR   rS   s   @r)   r   r   ]  s         ,0 , , , , , ,  26
@ 
@ !.
@ +,	
@
 

@ 
@ 
@ ^ 
@ 
@ 
@ 
@ 
@r(   r   c                   Z     e Zd Zdef fdZe	 ddeej                 fd            Z	 xZ
S )Ovis2VisionTransformerr[   c                     t                                                       || _        t          |          | _        t          |          | _        t          |j        |j	                  | _
        d| _        d S )NF)r7   r8   r[   ro   r   r   encoderr4   r>   r   r   r   rd   s     r)   r8   zOvis2VisionTransformer.__init__}  sc    /77)&11$V%79LMM&+###r(   Nr   c                     |                      |          } | j        d||d|}|j        }|                     |          }t	          |          S )N)r   r   r   r'   )r   r   r   r   r   )r=   r   r   r   r/   encoder_outputsr   s          r)   rM   zOvis2VisionTransformer.forward  sr     55+74< ,
'),
 ,
 ,
 ,
 ,= MM*;<<1BCCCCr(   rg   )r    r!   r"   r   r8   r   r   r$   r   rM   rR   rS   s   @r)   r   r   |  s        ,0 , , , , , ,  26D D !.D D D D D D D Dr(   r   c                   <     e Zd Zdej        dej        f fdZ xZS )Ovis2VisualEmbeddingTablevisual_tokensr   c                     |j         t          j        t          j        t          j        t          j        t          j        fv r!t                                          |          S t          j	        || j
                  S rg   )rE   r$   int8int16int32int64longr7   rM   r   r;   )r=   r   r@   s     r)   rM   z!Ovis2VisualEmbeddingTable.forward  sR    5:u{EKV[V`"aaa77??=111|M4;777r(   )r    r!   r"   r$   r   rM   rR   rS   s   @r)   r   r     sO        8U\ 8el 8 8 8 8 8 8 8 8 8 8r(   r   c                   D    e Zd ZU eed<   dZdZdgZdZdZ	dZ
dZdZdZdZdS )Ovis2PreTrainedModelr[   modelTr   r.   N)r    r!   r"   r   r&   base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_cache_class_supports_flash_attn_supports_flex_attn_supports_sdpa_can_compile_fullgraph_supports_attention_backendr'   r(   r)   r   r     s\         &*#/0"3 N!"&r(   r   r-   r   c                    |                      |          }|                    |d          d         }t          j        | t          j                                      ||d          }||                                z
  |z   }|S )NTrC   r   )memory_formatg      ?)r   maxr$   
zeros_likelegacy_contiguous_formatscatter_detach)r-   r   y_softindexy_hardrets         r)   hard_softmaxr    sv    ^^C  FJJsDJ))!,EfE4RSSS\\]`bgilmmF
6==??
"V
+CJr(   c                   n     e Zd ZU eed<   def fdZdej        deej	        ej	        f         fdZ
 xZS )Ovis2VisionModelr[   c                 x   t                                          |           || _        t          |          | _        |j        | _        |j        | _        t          j        |j	        |j
        z  |j
        z  | j        | j        z
  d          | _        t          j        | j        | j        z
            | _        d S NFrX   )r7   r8   r[   r   transformernum_visual_indicator_tokens
vocab_sizer   r]   r>   hidden_stridehead_linear	LayerNorm	head_normrd   s     r)   r8   zOvis2VisionModel.__init__  s       1&99+1+M( +9!558LLOd>>
 
 

 do8X&XYYr(   r   r   c           	      p    | j         |fi |}|d         }| j        j        dk    r|j        \  }}}| j        j        }t	          t          j        |                    }	|	|	z  |k    rt          d          ||	|z  z
  |z  }
t          j	        
                    |ddd|
d|
fdd          }|	|
z  }	|                    ||	|z  ||	|z  ||          }|                    dddddd          }|                    |d	||z  |z            }|                     |          }|                     |          }| j        j        d
k    r#t          j	                            |d	d          }nS| j        j        dk    rt#          |d	          }n1| j        j        dk    r!t          j	                            |d	          }|S )Nr   r   z.Token sequence length must be a perfect squareconstantr   r         rB   gumbel_argmaxT)r   hard	st_argmaxr   r   )r  r[   r  rO   intmathsqrtr   r   r   padr   permuter  r   tokenize_functiongumbel_softmaxr  r   )r=   r   r   outputsr   
num_imagesseq_len
hidden_dimr  sqrt_lpad_sizer-   
prob_tokens                r)   rM   zOvis2VisionModel.forward  s   "$"<::6::#AJ;$q((.?.E+J K5M7++,,F')) !QRRR%-)?@MQH " 1 12CaAxYZ\dEegqst u uhF 1 9 9Fm3]FmD[]jlv! ! !2 9 9!Q1a K K 1 9 9B =
 J! ! !!"344'';(O;;55f"45PPJJ[*k99%f"555JJ[*i77..v2.>>Jr(   )r    r!   r"   r   r&   r8   r$   r%   r1   r   rM   rR   rS   s   @r)   r  r    s         Z0 Z Z Z Z Z Z!E$5 !E%,X]XdJdDe ! ! ! ! ! ! ! !r(   r  zu
    The Ovis2 model which consists of a vision backbone and a language model, without a language modeling head.
    c            !           e Zd Zi Zdef fdZd Zd Zd Zd Z	de
j        de
j        fd	Zd
e
j        de
j        de
j        fdZee	 	 	 	 	 	 	 	 	 	 	 	 	 dd
ee
j                 dee
j                 dee
j                 dee
j                 dee         dee
j                 dee
j                 dee         dee         dee         dee         dee
j                 deee
j        f         deeef         fd                        Z xZS )
Ovis2Modelr[   c                 z   t                                          |           t          |j                  | _        t          j        |j                  | _        t          |j        j
        |j                  | _        |j        j
        | _        |j
        | _
        |j        | _        |                                  d S rg   )r7   r8   r  vision_configvision_towerr   from_configtext_configlanguage_modelr   r  r>   visual_embeddings_tablevisual_vocab_sizevisual_indicator_token_ids	post_initrd   s     r)   r8   zOvis2Model.__init__  s       ,V-ABB'3F4FGG'@AUA`bhbt'u'u$!'!5!@ +*0*K'r(   c                 4    | j                                         S rg   )r>  get_input_embeddingsrP   s    r)   rD  zOvis2Model.get_input_embeddings  s    "77999r(   c                 :    | j                             |           d S rg   )r>  set_input_embeddingsr=   r   s     r)   rF  zOvis2Model.set_input_embeddings  s    0077777r(   c                     || _         d S rg   r>  r=   decoders     r)   set_decoderzOvis2Model.set_decoder  s    %r(   c                     | j         S rg   rI  rP   s    r)   get_decoderzOvis2Model.get_decoder
  s    ""r(   r   r   c                    |                      |          }|j        \  }}}t          j        ||| j         j        f|j        |j        d|j                  }t          j        ||gd          }| 	                    |          }t          j
        | j        | j         j        z
  | j        t          j                                      |j                  }| 	                    |          }||fS )a  
        Obtains image last hidden states from the vision tower and apply multimodal projection.

        Args:
            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`):
               The tensors corresponding to the input images.
            vision_feature_layer (`Union[int, list[int]]`, *optional*):
                The index of the layer to select the vision feature. If multiple indices are provided,
                the vision feature of the corresponding indices will be concatenated to form the
                vision features.
            vision_feature_select_strategy (`str`, *optional*):
                The feature selection strategy used to select the vision feature from the vision backbone.
                Can be one of `"default"` or `"full"`
        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
        F)rE   devicerequires_gradlayoutr   r(  r   )r;  rO   r$   zerosr  rE   rP  rR  catr?  r   r@  r   rF   )	r=   r   image_featuresr   img_seq_lenr   padding_tensorvisual_indicatorvisual_indicator_featuress	            r)   get_image_featureszOvis2Model.get_image_features  s    ( **<88%3%9"
Kd&7&ST &!(!(
 
 
 NN#CKKK55nEE <"T%6%RR"*
 
 
 "^"
#
#	 	
 %)$@$@AQ$R$R!888r(   	input_idsr   rU  c                 \   |e| |                                  t          j        | j        j        t          j        |j                            k    }|                    d          }n|| j        j        k    }|                                }|	                    d          
                    |                              |j                  }|j        d         |j        d         z  }||                                         |                                k    rt          d| d|           |S )z
        Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
        equal to the length of multimodal features. If the lengths are different, an error is raised.
        NrE   rP  rB   r   r   z6Image features and image tokens do not match: tokens: z, features )rD  r$   tensorr[   image_token_idr   rP  allsum	unsqueeze	expand_asrF   rO   numelr   )r=   r[  r   rU  special_image_maskn_image_tokensn_image_featuress          r)   get_placeholder_maskzOvis2Model.get_placeholder_mask6  s/    !.2M$2K2K2M2MT[7uzR_Rfggg3 3 " "4!7!7!;!;!*dk.H!H+//11/99"==GGVVYYZgZnoo)/2^5I!5LL+,22448L8L8N8NNNvvvdtvv   "!r(   Nr   r   rw   r.   labels	use_cacheoutput_attentionsoutput_hidden_statesreturn_dictcache_positionlogits_to_keepc                    |	|	n| j         j        }	|
|
n| j         j        }
|d u |d uz  rt          d          | |                                 |          }|2|                     |          \  }}|                     |||          }|                    ||          }t          | j	                  D ]\  }}|[| |                                 t          j        |t          j        |j                            k    }|                    d          }n||k                        |j                  }|                                rB||                             ||                                       |j        |j                  ||<    | j        d	||||||	|
d||d
|}t)          |j        |j        |j        |j        ||nd           S )
Nz:You must specify exactly one of input_ids or inputs_embedsr   )r   rU  r]  rB   T)
r   rw   r.   r   rj  rk  rl  rm  rn  ro  )r   r.   r/   r0   r   r'   )r[   rk  rl  r   rD  rZ  rh  masked_scatter	enumeraterA  r$   r^  r   rP  r`  rF   anyrc  rE   r>  r   r   r.   r/   r0   )r=   r[  r   r   rw   r.   r   ri  rj  rk  rl  rm  rn  ro  r   rU  rY  re  ivisual_indicator_idmaskr0  s                         r)   rM   zOvis2Model.forwardN  sC   & 2C1N--TXT_Tq$8$D  $+Jj 	 -t";< 	[YZZZ 7D5577	BBM#8<8O8O]i8O8j8j5N5!%!:!:+- "; " "
 *889K^\\M*3D4S*T*T  &&$(,GD,E,E,G,G%8
S`Sghhh- - D  88B<<DD%)<<@@AUVVD88:: 1!4"=#677M0-2EFF "$' &$% 
)%+'/!5))
 
 
 
 (%7#3!/)2>2JPT
 
 
 	
r(   NNNNNNNNNNNNr   )r    r!   r"   _checkpoint_conversion_mappingr   r8   rD  rF  rL  rN  r$   r%   rZ  
LongTensorrh  r   r   r   r   r
   boolr   r)  r1   r   rM   rR   rS   s   @r)   r8  r8    sF        &("	{ 	 	 	 	 	 	: : :8 8 8& & &# # #'9''9 
	'9 '9 '9 '9R")":?:K"]b]n" " " "0  15481537+/59-1$(,0/3&*5934J
 J
E,-J
 u01J
 !.	J

 u/0J
 "%J
   12J
 )*J
 D>J
 $D>J
 'tnJ
 d^J
 !!12J
 c5</0J
  
u..	/!J
 J
 J
 ^ J
 J
 J
 J
 J
r(   r8  c            !       H    e Zd Zi ZdgZdef fdZd Zd Zde	j
        fdZd Zd	 Zd
ej        fdZed             Zed             Zed             Zee	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                 d
eej                 deej                 deej                 dee         deej                 deej                 dee         dee         dee         dee         deej                 deeej        f         deeef         fd                        Z 	 	 	 	 	 	 d  fd	Z! xZ"S )!Ovis2ForConditionalGenerationzlm_head.weightr[   c                     t                                          |           t          |          | _        t	          j        |j        |j        d          | _        | 	                                 d S r  )
r7   r8   r8  r   r   r]   r>   r  lm_headrB  rd   s     r)   r8   z&Ovis2ForConditionalGeneration.__init__  s^       ''
y!3V5FUSSSr(   c                 4    | j                                         S rg   )r   rD  rP   s    r)   rD  z2Ovis2ForConditionalGeneration.get_input_embeddings  s    z..000r(   c                 :    | j                             |           d S rg   )r   rF  rG  s     r)   rF  z2Ovis2ForConditionalGeneration.set_input_embeddings  s    
''.....r(   r   c                     | j         S rg   )r  rP   s    r)   get_output_embeddingsz3Ovis2ForConditionalGeneration.get_output_embeddings  s
    |r(   c                 :    | j                             |           d S rg   )r   rL  rJ  s     r)   rL  z)Ovis2ForConditionalGeneration.set_decoder  s    
w'''''r(   c                 4    | j                                         S rg   )r   rN  rP   s    r)   rN  z)Ovis2ForConditionalGeneration.get_decoder  s    z%%'''r(   r   c                 8    | j                             |          S )Nrq  )r   rZ  )r=   r   s     r)   rZ  z0Ovis2ForConditionalGeneration.get_image_features  s    z,,,,GGGr(   c                     | j         j        S rg   )r   r>  rP   s    r)   r>  z,Ovis2ForConditionalGeneration.language_model  s    z((r(   c                     | j         j        S rg   )r   r;  rP   s    r)   r;  z*Ovis2ForConditionalGeneration.vision_tower  s    z&&r(   c                      t          d          )NzNot needed for Ovis2)AttributeErrorrP   s    r)   multi_modal_projectorz3Ovis2ForConditionalGeneration.multi_modal_projector  s    3444r(   Nr   r[  r   rw   r.   r   ri  rj  rk  rl  rm  rn  ro  c                    |	|	n| j         j        }	|
|
n| j         j        }
 | j        d||||||||	|
d|d|}|d         }t	          |t
                    rt          | d          n|}|                     |dd|ddf                   }d}|  | j        d||| j         j	        j
        d|}t          |||j        |j        |j        |j                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```NT)r[  r   r   rw   r.   r   rj  rk  rl  rm  rn  r   )r-   ri  r  )r,   r-   r.   r/   r0   r   r'   )r[   rk  rl  r   
isinstancer)  slicer  loss_functionr=  r  r+   r.   r/   r0   r   )r=   r[  r   r   rw   r.   r   ri  rj  rk  rl  rm  rn  ro  r   r0  r/   slice_indicesr-   r,   s                       r)   rM   z%Ovis2ForConditionalGeneration.forward  sN   \ 2C1N--TXT_Tq$8$D  $+Jj 	 $* 
%)%+'/!5)
 
 
 
  
8B>SV8W8Wk~ot444]kmAAA}aaa,?@AA%4% f9P9[ _e D +#3!/) ' ;
 
 
 	
r(   c           	      j     t                      j        |f|||||d|}	|d         dk    r||	d<   |	S )N)r.   r   r   rn  ro  r   r   )r7   prepare_inputs_for_generation)r=   r[  r.   r   r   r   rn  ro  r   model_inputsr@   s             r)   r  z;Ovis2ForConditionalGeneration.prepare_inputs_for_generation  sg     =uww<
+')))
 
 
 
 !!! ,8L(r(   rx  )NNNNNN)#r    r!   r"   ry  _tied_weights_keysr   r8   rD  rF  r   Moduler  rL  rN  r$   r%   rZ  propertyr>  r;  r  r   r   r   rz  r   r
   r{  r   r)  r1   r+   rM   r  rR   rS   s   @r)   r}  r}    s       %'"*+{      1 1 1/ / /ry    ( ( (( ( (Hu/@ H H H H ) ) X) ' ' X' 5 5 X5  15481537+/59-1$(,0/3&*5934R
 R
E,-R
 u01R
 !.	R

 u/0R
 "%R
   12R
 )*R
 D>R
 $D>R
 'tnR
 d^R
 !!12R
 c5</0R
  
u11	2!R
 R
 R
 ^ R
n          r(   r}  )r   r8  r}  )r   )=r*  dataclassesr   typingr   r   r   r$   r   activationsr	   cache_utilsr
   
generationr   integrationsr   modeling_layersr   modeling_outputsr   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   autor   configuration_ovis2r   r   r   r+   r  r4   rU   ro   r   floatr   r   r   r   r   r   r   r   r   r   r)  r  r  r8  r}  __all__r'   r(   r)   <module>r     sx  ,  ! ! ! ! ! ! , , , , , , , , , ,        ! ! ! ! ! !             ) ) ) ) ) ) 7 7 7 7 7 7 9 9 9 9 9 9 H H H H H H H H F F F F F F F F & & & & & & V V V V V V V V V V V V       ? ? ? ? ? ? ? ?   
< < < < <6 < <  <   
< < < < <+ < <  <0 Y''J J J J J29 J J ('J(    RY        BI   P % %I%<% 
% <	%
 U\*% % % % % %.:) :) :) :) :)29 :) :) :)z    ry    :) :) :) :) :)RY :) :) :)z    8   2@ @ @ @ @ @ @ @>D D D D DRY D D D<8 8 8 8 8 8 8 8' ' ' ' '? ' ' ' C    1 1 1 1 1+ 1 1 1h   
g
 g
 g
 g
 g
% g
 g
 
g
T [ [ [ [ [$8/ [ [ [| R
Q
Qr(   