
     `i0                    |   d Z ddlmZ ddlmZ ddlmZmZmZ ddl	Z	ddl	m
Z
mZ ddlmZ dd	lmZmZ dd
lmZ ddlmZmZ ddlmZ ddlmZmZmZmZmZmZ ddlm Z m!Z!m"Z"  e            rddl#m$Z$  ej%        e&          Z'de	j
        de	j
        fdZ(de	j
        de	j
        fdZ)ee G d de                                  Z*de
de
fdZ+de
de
fdZ,d Z-d Z.e ed           G d  d!e                                  Z/e ed"           G d# d$e                                  Z0 G d% d&ej1                  Z2 G d' d(ej1                  Z3 G d) d*ej1                  Z4 G d+ d,ej1                  Z5 G d- d.e          Z6e G d/ d0e                      Z7 G d1 d2ej1                  Z8 G d3 d4ej1                  Z9 G d5 d6e7          Z: G d7 d8ej1                  Z; G d9 d:e7          Z<e G d; d<e7                      Z= G d= d>ej1                  Z> G d? d@ej1                  Z? G dA dBe7          Z@g dCZAdS )DzPyTorch OWLv2 model.    )	dataclass)	lru_cache)AnyOptionalUnionN)Tensornn   )ACT2FN) _create_4d_causal_attention_mask_prepare_4d_attention_mask)GradientCheckpointingLayer)BaseModelOutputBaseModelOutputWithPooling)PreTrainedModel)ModelOutputauto_docstringfilter_out_non_signature_kwargsis_vision_availablelogging	torch_int   )Owlv2ConfigOwlv2TextConfigOwlv2VisionConfig)center_to_corners_formatlogitsreturnc                     t           j                            | t          j        t          |           | j                            S )Ndevice)r	   
functionalcross_entropytorcharangelenr!   )r   s    |/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/owlv2/modeling_owlv2.pycontrastive_lossr(   3   s3    =&&vu|CKKPVP]/^/^/^___    
similarityc                 r    t          |           }t          |                                           }||z   dz  S )Ng       @)r(   t)r*   caption_loss
image_losss      r'   
owlv2_lossr/   8   s4    #J//L!*,,..11J:%,,r)   c                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeed<   dZeed	<   d
ee         fdZdS )Owlv2Outputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
        Contrastive loss for image-text similarity.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
        similarity scores.
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
        similarity scores.
    text_embeds (`torch.FloatTensor` of shape `(batch_size * num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim`):
        The image embeddings obtained by applying the projection layer to the pooled output of
        [`Owlv2VisionModel`].
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nlosslogits_per_imagelogits_per_texttext_embedsimage_embedstext_model_outputvision_model_outputr   c                 ^     t           fd                                 D                       S )Nc              3   t   K   | ]2}|d vr|         n!t          |                                          V  3dS )r7   r8   Ngetattrto_tuple.0kselfs     r'   	<genexpr>z'Owlv2Output.to_tuple.<locals>.<genexpr>^   c       
 
  LLLDGGRYZ^`aRbRbRkRkRmRm
 
 
 
 
 
r)   tuplekeysrB   s   `r'   r>   zOwlv2Output.to_tuple]   C     
 
 
 
YY[[
 
 
 
 
 	
r)   )__name__
__module____qualname____doc__r2   r   r$   FloatTensor__annotations__r3   r4   r5   r6   r7   r   r8   rF   r   r>    r)   r'   r1   r1   >   s          ( )-D(5$
%,,,48hu0188837OXe/0777/3K%+,33304L(5,-4444818886:3:::
%* 
 
 
 
 
 
r)   r1   r,   c                     |                                  r5| j        t          j        t          j        fv r| n|                                 S | j        t          j        t          j        fv r| n|                                 S N)	is_floating_pointdtyper$   float32float64floatint32int64int)r,   s    r'   _upcastr[   e   se     GGu}===qq17799LGU[999qqquuwwFr)   boxesc                     t          |           } | dddf         | dddf         z
  | dddf         | dddf         z
  z  S )a  
    Computes the area of a set of bounding boxes, which are specified by its (x1, y1, x2, y2) coordinates.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes for which the area will be computed. They are expected to be in (x1, y1, x2, y2) format with `0 <= x1
            < x2` and `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: a tensor containing the area for each box.
    N   r   r
   r   )r[   )r\   s    r'   box_arear_   n   sT     ENNE!!!Q$K%1+%%1+aaad*CDDr)   c                    t          |           }t          |          }t          j        | d d d d df         |d d d df                   }t          j        | d d d dd f         |d d dd f                   }||z
                      d          }|d d d d df         |d d d d df         z  }|d d d f         |z   |z
  }||z  }	|	|fS )Nr^   r   minr   )r_   r$   maxrb   clamp)
boxes1boxes2area1area2left_topright_bottomwidth_heightinterunionious
             r'   box_iouro      s    VEVEy4!,fQQQUm<<H9VAAAtQRRK0&ABB-@@L 8+22q299LAAAq!LAAAq$99E!!!T'NU"U*E
%-C:r)   c                 n   | ddddf         | ddddf         k                                     st          d|            |ddddf         |ddddf         k                                     st          d|           t          | |          \  }}t          j        | dddddf         |ddddf                   }t          j        | dddddf         |ddddf                   }||z
                      d          }|dddddf         |dddddf         z  }|||z
  |z  z
  S )z
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)
    Nr^   z<boxes1 must be in [x0, y0, x1, y1] (corner) format, but got z<boxes2 must be in [x0, y0, x1, y1] (corner) format, but got r   ra   r   )all
ValueErrorro   r$   rb   rc   rd   )re   rf   rn   rm   top_leftbottom_rightrk   areas           r'   generalized_box_iourv      s    111abb5MVAAArrE]*//11 b`X^``aaa111abb5MVAAArrE]*//11 b`X^``aaa((JCy4!,fQQQUm<<H9VAAAtQRRK0&ABB-@@L 8+22q299L111a <111a#88D$,$&&&r)   z5
    Output type of [`Owlv2ForObjectDetection`].
    )custom_introc                   ^   e Zd ZU dZdZeej                 ed<   dZ	ee
         ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed	<   dZeej                 ed
<   dZeed<   dZeed<   dee         fdZdS )Owlv2ObjectDetectionOutputa  
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided)):
        Total loss as a linear combination of a negative log-likehood (cross-entropy) for class prediction and a
        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
        scale-invariant IoU loss.
    loss_dict (`Dict`, *optional*):
        A dictionary containing the individual losses. Useful for logging.
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    objectness_logits (`torch.FloatTensor` of shape `(batch_size, num_patches, 1)`):
        The objectness logits of all image patches. OWL-ViT represents images as a set of image patches where the
        total number of patches is (image_size / patch_size)**2.
    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
        possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to retrieve the
        unnormalized bounding boxes.
    text_embeds (`torch.FloatTensor` of shape `(batch_size, num_max_text_queries, output_dim`):
        The text embeddings obtained by applying the projection layer to the pooled output of [`Owlv2TextModel`].
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes image
        embeddings for each patch.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nr2   	loss_dictr   objectness_logits
pred_boxesr5   r6   class_embedsr7   r8   r   c                 ^     t           fd                                 D                       S )Nc              3   t   K   | ]2}|d vr|         n!t          |                                          V  3dS r;   r<   r?   s     r'   rC   z6Owlv2ObjectDetectionOutput.to_tuple.<locals>.<genexpr>   rD   r)   rE   rH   s   `r'   r>   z#Owlv2ObjectDetectionOutput.to_tuple   rI   r)   )rJ   rK   rL   rM   r2   r   r$   rN   rO   rz   dictr   r{   r|   r5   r6   r}   r7   r   r8   rF   r   r>   rP   r)   r'   ry   ry      s"         > )-D(5$
%,,, $Ix~$$$*.FHU&'...59x 12999.2J*+222/3K%+,33304L(5,-44404L(5,-4444818886:3:::
%* 
 
 
 
 
 
r)   ry   zL
    Output type of [`Owlv2ForObjectDetection.image_guided_detection`].
    c                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeej                 ed<   dZeed	<   dZeed
<   dee         fdZdS )%Owlv2ImageGuidedObjectDetectionOutputa  
    logits (`torch.FloatTensor` of shape `(batch_size, num_patches, num_queries)`):
        Classification logits (including no-object) for all queries.
    image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
        image embeddings for each patch.
    query_image_embeds (`torch.FloatTensor` of shape `(batch_size, patch_size, patch_size, output_dim`):
        Pooled output of [`Owlv2VisionModel`]. OWLv2 represents images as a set of image patches and computes
        image embeddings for each patch.
    target_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual target image in the batch
        (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    query_pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_patches, 4)`):
        Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
        values are normalized in [0, 1], relative to the size of each individual query image in the batch
        (disregarding possible padding). You can use [`~Owlv2ImageProcessor.post_process_object_detection`] to
        retrieve the unnormalized bounding boxes.
    class_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`):
        Class embeddings of all image patches. OWLv2 represents images as a set of image patches where the total
        number of patches is (image_size / patch_size)**2.
    text_model_output (tuple[`BaseModelOutputWithPooling`]):
        The output of the [`Owlv2TextModel`].
    vision_model_output (`BaseModelOutputWithPooling`):
        The output of the [`Owlv2VisionModel`].
    Nr   r6   query_image_embedstarget_pred_boxesquery_pred_boxesr}   r7   r8   r   c                 ^     t           fd                                 D                       S )Nc              3   t   K   | ]2}|d vr|         n!t          |                                          V  3dS r;   r<   r?   s     r'   rC   zAOwlv2ImageGuidedObjectDetectionOutput.to_tuple.<locals>.<genexpr>  rD   r)   rE   rH   s   `r'   r>   z.Owlv2ImageGuidedObjectDetectionOutput.to_tuple  rI   r)   )rJ   rK   rL   rM   r   r   r$   rN   rO   r6   r   r   r   r}   r7   r   r8   rF   r   r>   rP   r)   r'   r   r      s          8 +/FHU&'...04L(5,-4446:!23:::59x 1299948hu0188804L(5,-4444818886:3:::
%* 
 
 
 
 
 
r)   r   c                   z     e Zd Zdef fdZdej        dededej        fdZdd	ej	        d
e
dej        fdZ xZS )Owlv2VisionEmbeddingsconfigc                 b   t                                                       |j        | _        || _        |j        | _        t          j        t          j	        |j                            | _
        t          j        |j        | j        |j        |j        d          | _        |j        |j        z  dz  | _        | j        dz   | _        t          j        | j        | j                  | _        |                     dt          j        | j                                      d          d           d S )NF)in_channelsout_channelskernel_sizestridebiasr^   r   position_idsr   
persistent)super__init__
patch_sizer   hidden_size	embed_dimr	   	Parameterr$   randnclass_embeddingConv2dnum_channelspatch_embedding
image_sizenum_patchesnum_positions	Embeddingposition_embeddingregister_bufferr%   expandrB   r   	__class__s     r'   r   zOwlv2VisionEmbeddings.__init__  s    ++!|EK8J,K,KLL!y+)$ 
  
  
 #-1BBqH!-1"$,t/A4>"R"R^U\$:L-M-M-T-TU\-]-]jopppppr)   
embeddingsheightwidthr   c                    |j         d         dz
  }| j        j                            d          }|j         d         dz
  }t          j                                        s&||k    r ||k    r|                     | j                  S |ddddf         }|ddddf         }|j         d         }	|| j        z  }
|| j        z  }t          |dz            }|
                    d|||	          }|                    dddd          }t          j                            ||
|fdd	
          }|                    dddd                              dd|	          }t	          j        ||fd          S )a   
        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
        images. This method is also adapted to support torch.jit tracing.

        Adapted from:
        - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
        - https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
        r   r   Nr   g      ?r
   r^   bicubicF)sizemodealign_cornersdim)shaper   weight	unsqueezer$   jit
is_tracingr   r   r   reshapepermuter	   r"   interpolateviewcat)rB   r   r   r   r   r   r   class_pos_embedpatch_pos_embedr   
new_height	new_widthsqrt_num_positionss                r'   interpolate_pos_encodingz.Owlv2VisionEmbeddings.interpolate_pos_encoding+  s    !&q)A-!4;EEaHH*03a7 y##%% 	>+*F*F6UZ??**4+<===,QQQU3,QQQU3r"t.
T_,	&}c'9::)11!5GI[]`aa)11!Q1==-33i(	 4 
 
 *11!Q1==BB1b#NNy/?;CCCCr)   Fpixel_valuesr   c                 v   |j         \  }}}}|                     |          }|                    d                              dd          }| j                            |dd          }t          j        ||gd          }	|r|	|                     |	||          z   }	n|	| 	                    | j
                  z   }	|	S )Nr^   r   r   r   )r   r   flatten	transposer   r   r$   r   r   r   r   )
rB   r   r   
batch_size_r   r   patch_embedsr}   r   s
             r'   forwardzOwlv2VisionEmbeddings.forwardT  s    '3'9$
Avu++L99#++A..88A>>+22:q"EEYl;CCC
# 	Q#d&C&CJPVX]&^&^^JJ#d&=&=d>O&P&PPJr)   F)rJ   rK   rL   r   r   r$   r   rZ   r   rN   boolr   __classcell__r   s   @r'   r   r     s        q0 q q q q q q*'D5< 'D 'DUX 'D]b]i 'D 'D 'D 'DR E$5 QU bgbn        r)   r   c            	            e Zd Zdef fdZ	 	 	 d	deej                 deej                 deej                 dej	        fdZ
 xZS )
Owlv2TextEmbeddingsr   c                 \   t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        | 	                    dt          j        |j                                      d          d           d S )Nr   r   Fr   )r   r   r	   r   
vocab_sizer   token_embeddingmax_position_embeddingsr   r   r$   r%   r   r   s     r'   r   zOwlv2TextEmbeddings.__init__d  s    !|F,=v?QRR"$,v/MvOa"b"b 	EL)GHHOOPWXXej 	 	
 	
 	
 	
 	
r)   N	input_idsr   inputs_embedsr   c                     ||j         d         n|j         d         }|| j        d d d |f         }||                     |          }|                     |          }||z   }|S )Nr   )r   r   r   r   )rB   r   r   r   
seq_lengthposition_embeddingsr   s          r'   r   zOwlv2TextEmbeddings.forwardn  s     -6,AY_R((}GZ[]G^
,QQQ^<L  00;;M"55lCC"%88
r)   )NNN)rJ   rK   rL   r   r   r   r$   
LongTensorrN   r   r   r   r   s   @r'   r   r   c  s        
 
 
 
 
 
 
 153759	 E,- u/0   12	
 
       r)   r   c                        e Zd ZdZ fdZdej        dedefdZ	 	 	 dd	ej        d
e	ej                 de	ej                 de	e
         deej        e	ej                 e	eej                          f         f
dZ xZS )Owlv2Attentionz=Multi-headed attention from 'Attention Is All You Need' paperc                 t   t                                                       || _        |j        | _        |j        | _        | j        | j        z  | _        | j        | j        z  | j        k    r t          d| j         d| j         d          | j        dz  | _	        |j
        | _        t          j        | j        | j                  | _        t          j        | j        | j                  | _        t          j        | j        | j                  | _        t          j        | j        | j                  | _        d S )Nz;embed_dim must be divisible by num_heads (got `embed_dim`: z and `num_heads`: z).      )r   r   r   r   r   num_attention_heads	num_headshead_dimrr   scaleattention_dropoutdropoutr	   Lineark_projv_projq_projout_projr   s     r'   r   zOwlv2Attention.__init__  s   +3$.8=4>)T^;;'dn ' 'N' ' '   ]D(
/i??i??i??	$.$.AAr)   tensorseq_lenbszc                     |                     ||| j        | j                                      dd                                          S )Nr   r^   )r   r   r   r   
contiguous)rB   r   r   r   s       r'   _shapezOwlv2Attention._shape  s<    {{3GGQQRSUVWWbbdddr)   NFhidden_statesattention_maskcausal_attention_maskoutput_attentionsr   c                    |                                 \  }}}|                     |          | j        z  }|                     |                     |          d|          }	|                     |                     |          d|          }
|| j        z  d| j        f} |                     |||          j        | } |	j        | }	 |
j        | }
|	                     d          }t          j
        ||	                    dd                    }|                                 || j        z  ||fk    r2t          d|| j        z  ||f d|                                            ||                                 |d||fk    r+t          d|d||f d|                                            |                    || j        ||          |z   }|                    || j        z  ||          }||                                 |d||fk    r+t          d|d||f d|                                            |                    || j        ||          |z   }|                    || j        z  ||          }t          j                            |d          }|r=|                    || j        ||          }|                    || j        z  ||          }nd}t          j                            || j        | j        	          }|                    |
j                  }t          j
        ||
          }|                                 || j        z  || j        fk    r5t          d
|| j        || j        f d|                                            |                    || j        || j                  }|                    dd          }|                    |||          }|                     |          }||fS )z#Input shape: Batch x Time x Channelr   r   r^   z$Attention weights should be of size z	, but is Nz!Attention mask should be of size r   )ptrainingz `attn_output` should be of size )r   r   r   r   r   r   r   r   r   r$   bmmr   rr   r	   r"   softmaxr   r   torT   r   r   )rB   r   r   r   r   r   tgt_lenr   query_states
key_statesvalue_states
proj_shapesrc_lenattn_weightsattn_weights_reshaped
attn_probsattn_outputs                    r'   r   zOwlv2Attention.forward  sF    #0"4"4"6"6Wi {{=11DJ>[[]!;!;REE
{{4;;}#=#=r3GGDN*B>
Ct{{<#>>CZP$Z_j1
(|(*5//!$$yz/C/CAq/I/IJJ3#7'"JJJ*dn8LgW^7_ * * %%''* *   !,$))++Q/III 7a'8R 7 7-22447 7   (,,S$.'7SSVkkL',,S4>-A7GTTL%""$$a'(BBB ta'8Rtt]k]p]p]r]rtt   (,,S$.'7SSVddL',,S4>-A7GTTL},,\r,BB 	)
 %1$5$5c4>7T[$\$\!055cDN6JGU\]]LL$(!]**<4<RVR_*``
  ]]<#566
i
L99#"6!OOO)CRVR_3` ) )$$&&) )  
 "&&sDNGT]SS!++Aq11!))#w	BBmmK00111r)   NNF)rJ   rK   rL   rM   r   r$   r   rZ   r   r   r   rF   r   r   r   s   @r'   r   r     s       GGB B B B B&eU\ eC ec e e e e 268<,1O2 O2|O2 !.O2  (5	O2
 $D>O2 
u|Xel3XeEL>Q5RR	SO2 O2 O2 O2 O2 O2 O2 O2r)   r   c                   B     e Zd Z fdZdej        dej        fdZ xZS )Owlv2MLPc                    t                                                       || _        t          |j                 | _        t          j        |j        |j	                  | _
        t          j        |j	        |j                  | _        d S rR   )r   r   r   r   
hidden_actactivation_fnr	   r   r   intermediate_sizefc1fc2r   s     r'   r   zOwlv2MLP.__init__  sf    #F$569V/1IJJ9V5v7IJJr)   r   r   c                     |                      |          }|                     |          }|                     |          }|S rR   )r  r  r  )rB   r   s     r'   r   zOwlv2MLP.forward  s=    //**=99//r)   )rJ   rK   rL   r   r$   r   r   r   r   s   @r'   r  r    sc        K K K K KU\ el        r)   r  c                        e Zd Zdef fdZ	 d
dej        dej        dej        dee         de	ej
                 f
d	Z xZS )Owlv2EncoderLayerr   c                 D   t                                                       |j        | _        t	          |          | _        t          j        | j        |j                  | _	        t          |          | _        t          j        | j        |j                  | _        d S Neps)r   r   r   r   r   	self_attnr	   	LayerNormlayer_norm_epslayer_norm1r  mlplayer_norm2r   s     r'   r   zOwlv2EncoderLayer.__init__   s    +'//<F<QRRRF##<F<QRRRr)   Fr   r   r   r   r   c                     |}|                      |          }|                     ||||          \  }}||z   }|}|                     |          }|                     |          }||z   }|f}|r||fz  }|S )aI  
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        )r   r   r   r   )r  r  r  r  )rB   r   r   r   r   residualr  outputss           r'   r   zOwlv2EncoderLayer.forward  s    " !((77&*nn')"7/	 '5 '
 '
#| !=0 ((77// =0 " 	'&Gr)   r   )rJ   rK   rL   r   r   r$   r   r   r   rF   rN   r   r   r   s   @r'   r  r    s        S{ S S S S S S -2& &|& &  %|	&
 $D>& 
u 	!& & & & & & & &r)   r  c                   >    e Zd ZU eed<   dZdZdgZdej	        fdZ
dS )Owlv2PreTrainedModelr   owlv2Tr  modulec                 X   | j         j        }t          |t                    rT|j        j        j                            d|dz             |j        j        j                            d|dz             nt          |t                    rt          j                            |j        d|j        dz  |z             t          j                            |j        j        |j         j        |z             t          j                            |j        j        |j         j        |z             n@t          |t                     r|j        dz  d|j         j        z  dz  z  |z  }|j        dz  |z  }t          j                            |j        j        |           t          j                            |j        j        |           t          j                            |j        j        |           t          j                            |j        j        |           nPt          |t,                    r|j         j        dz  d|j         j        z  dz  z  |z  }d|j         j        z  dz  |z  }t          j                            |j        j        |           t          j                            |j        j        |           nt          |t4                    rt          j                            |j        j        |j        dz  |z             t          j                            |j        j        |j        dz  |z             |j        j                             | j         j!                   t          |t          j"                  r=|j#        j        $                                 |j        j                             d           t          |t          j%                  rH|j        j                            d|           |j#        "|j#        j        $                                 dS dS dS )	zInitialize the weights        g{Gz?)meanstdr   )r&  r^         ?N)&r   initializer_factor
isinstancer   r   r   datanormal_r   r   r	   initr   r   r   initializer_ranger   num_hidden_layersr   r   r   r   r  r   r  r  
Owlv2Modeltext_projectiontext_embed_dimvisual_projectionvision_embed_dimlogit_scalefill_logit_scale_init_valuer  r   zero_r   )rB   r"  factorin_proj_stdout_proj_stdfc_stds         r'   _init_weightsz"Owlv2PreTrainedModel._init_weights9  s   /f122 	N").66CVd]6SSS%,199sQU9VVVV 566 	NGOOF2&BRTXBX[aBaObbbGOOF29v}?^ag?gOhhhGOOF5<&-BadjBjOkkkk// 	N!+T1q6=;Z7Z_c6cdgmmK",d2f<LGOOFM0kOBBBGOOFM0kOBBBGOOFM0kOBBBGOOFO2OEEEE)) 	N!=4d:FMDc@chl?lmpvvK&-33<vEFGOOFJ-6O:::GOOFJ-;O????
++ 		NGOO&-)4/&8     GOO(/+T1F:     #))$+*LMMMfbl++ 	*K""$$$M$$S)))fbi(( 	)M&&CV&<<<{& &&(((((	) 	)&&r)   N)rJ   rK   rL   r   rO   base_model_prefixsupports_gradient_checkpointing_no_split_modulesr	   Moduler<  rP   r)   r'   r   r   1  sV          &*#,-&)BI &) &) &) &) &) &)r)   r   c                        e Zd ZdZdef fdZ	 	 	 	 	 ddeej                 deej                 dee	         dee	         d	ee	         d
e
eef         fdZ xZS )Owlv2Encoderz
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`Owlv2EncoderLayer`].

    Args:
        config: Owlv2Config
    r   c                     t                                                       t          j        fdt	          j                  D                       | _        d| _        d S )Nc                 .    g | ]}t                    S rP   )r  )r@   r   r   s     r'   
<listcomp>z)Owlv2Encoder.__init__.<locals>.<listcomp>n  s"    $h$h$h1%6v%>%>$h$h$hr)   F)r   r   r	   
ModuleListranger.  layersgradient_checkpointingr   s    `r'   r   zOwlv2Encoder.__init__l  sY    m$h$h$h$hfNfHgHg$h$h$hii&+###r)   Nr   r   r   output_hidden_statesreturn_dictr   c                 \   ||n| j         j        }||n| j         j        }||n| j         j        }|rdnd}|rdnd}|}	| j        D ]/}
|r||	fz   } |
|	|||          }|d         }	|r||d         fz   }0|r||	fz   }|st          d |	||fD                       S t          |	||          S )a  
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`).
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Causal mask for the text model. Mask values selected in `[0, 1]`:
                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.
                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        NrP   )r   r   r   c              3      K   | ]}||V  	d S rR   rP   )r@   vs     r'   rC   z'Owlv2Encoder.forward.<locals>.<genexpr>  s(      eeqWXWdWdWdWdWdeer)   )last_hidden_stater   
attentions)r   r   rJ  use_return_dictrH  rF   r   )rB   r   r   r   r   rJ  rK  encoder_statesall_attentionsr   encoder_layerlayer_outputss               r'   r   zOwlv2Encoder.forwardq  s?   > 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]3=0:d%![ 	F 	FM# C!/=2B!B)M%"3	  M *!,M  F!/=3C2E!E 	?+}.>>N 	fee]NN$Seeeeee+>Vd
 
 
 	
r)   NNNNN)rJ   rK   rL   rM   r   r   r   r$   r   r   r   rF   r   r   r   r   s   @r'   rB  rB  c  s         ,{ , , , , , , 268<,0/3&*?
 ?
 !.?
  (5	?

 $D>?
 'tn?
 d^?
 
uo%	&?
 ?
 ?
 ?
 ?
 ?
 ?
 ?
r)   rB  c                        e Zd Zdef fdZe	 	 	 	 	 ddej        deej                 deej                 dee	         dee	         d	ee	         d
e
eef         fd            Z xZS )Owlv2TextTransformerr   c                     t                                                       || _        |j        }t	          |          | _        t          |          | _        t          j	        ||j
                  | _        d S r  )r   r   r   r   r   r   rB  encoderr	   r  r  final_layer_norm)rB   r   r   r   s      r'   r   zOwlv2TextTransformer.__init__  sf    &	-f55#F++ "YF<Q R R Rr)   Nr   r   r   r   rJ  rK  r   c                    ||n| j         j        }||n| j         j        }||n| j         j        }|                                }|                    d|d                   }|                     ||          }t          ||j        |j	                  }	|t          ||j                  }|                     |||	|||          }
|
d         }|                     |          }|t          j        |j        d         |j	                  |                    t          j                                      d                              |j	                  f         }|s||f|
dd         z   S t'          |||
j        |
j        	          S )
a|  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)
        Nr   )r   r   r    )r   r   r   r   rJ  rK  r   r   r   rO  pooler_outputr   rP  )r   r   rJ  rQ  r   r   r   r   rT   r!   r   rZ  r[  r$   r%   r   r   rZ   argmaxr   r   rP  )rB   r   r   r   r   rJ  rK  input_shaper   r   encoder_outputsrO  pooled_outputs                r'   r   zOwlv2TextTransformer.forward  s     2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]nn&&NN2{277	),WW
 !A,]5I!
 !
 !
 %7H[\\N,,')"7/!5# ' 
 
 ,A. 112CDD *L*03<M<TUUULL##**r*22556G6NOOQ

  	L%}58KKK)/')7&1	
 
 
 	
r)   rV  )rJ   rK   rL   r   r   r   r$   r   r   r   r   rF   r   r   r   r   s   @r'   rX  rX    s        S S S S S S S  26/3,0/3&*?
 ?
<?
 !.?
 u|,	?

 $D>?
 'tn?
 d^?
 
u00	1?
 ?
 ?
 ^?
 ?
 ?
 ?
 ?
r)   rX  c                        e Zd ZU eed<   def fdZdej        fdZd Z	e
	 	 	 	 ddej        deej                 d	ee         d
ee         dee         deeef         fd            Z xZS )Owlv2TextModelr   c                     t                                          |           t          |          | _        |                                  d S rR   )r   r   rX  
text_model	post_initr   s     r'   r   zOwlv2TextModel.__init__  s@       .v66r)   r   c                 $    | j         j        j        S rR   rf  r   r   rH   s    r'   get_input_embeddingsz#Owlv2TextModel.get_input_embeddings
  s    )99r)   c                 (    || j         j        _        d S rR   ri  )rB   values     r'   set_input_embeddingsz#Owlv2TextModel.set_input_embeddings  s    5:"222r)   Nr   r   r   rJ  rK  c                 6    |                      |||||          S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Examples:
        ```python
        >>> from transformers import AutoProcessor, Owlv2TextModel

        >>> model = Owlv2TextModel.from_pretrained("google/owlv2-base-patch16")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```r   r   r   rJ  rK  )rf  )rB   r   r   r   rJ  rK  s         r'   r   zOwlv2TextModel.forward  s.    < )/!5#  
 
 	
r)   )NNNN)rJ   rK   rL   r   rO   r   r	   r@  rj  rm  r   r$   r   r   r   r   rF   r   r   r   r   s   @r'   rd  rd    s              :bi : : : :; ; ;  26,0/3&*#
 #
<#
 !.#
 $D>	#

 'tn#
 d^#
 
u00	1#
 #
 #
 ^#
 #
 #
 #
 #
r)   rd  c                        e Zd Zdef fdZe	 	 	 	 ddej        dee	         dee	         dee	         d	ee	         d
e
eef         fd            Z xZS )Owlv2VisionTransformerr   c                 :   t                                                       || _        t          |          | _        t          j        |j        |j                  | _	        t          |          | _        t          j        |j        |j                  | _        d S r  )r   r   r   r   r   r	   r  r   r  pre_layernormrB  rZ  post_layernormr   s     r'   r   zOwlv2VisionTransformer.__init__9  s~    /77\&*<&BWXXX#F++ l6+=6CXYYYr)   NFr   r   rJ  r   rK  r   c                    ||n| j         j        }||n| j         j        }||n| j         j        }| j        j        j        j        }|                    |          }|                     ||          }| 	                    |          }| 
                    ||||          }|d         }	|	d d dd d f         }
|                     |
          }
|s|	|
f|dd          z   S t          |	|
|j        |j                  S )N)r   )r   r   rJ  rK  r   r   r]  )r   r   rJ  rQ  r   r   r   rT   r   rs  rZ  rt  r   r   rP  )rB   r   r   rJ  r   rK  expected_input_dtyper   ra  rO  rb  s              r'   r   zOwlv2VisionTransformer.forwardB  s?    2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]  $>EK#';<<Oghh**=99,,'/!5#	 ' 
 
 ,A.)!!!Q'2++M:: 	L%}58KKK)/')7&1	
 
 
 	
r)   )NNFN)rJ   rK   rL   r   r   r   r$   rN   r   r   r   rF   r   r   r   r   s   @r'   rq  rq  8  s        Z0 Z Z Z Z Z Z  -1/338&*)
 )
')
 $D>)
 'tn	)

 #+4.)
 d^)
 
u00	1)
 )
 )
 ^)
 )
 )
 )
 )
r)   rq  c                        e Zd ZU eed<   dZdef fdZdej        fdZ	e
	 	 	 	 	 ddeej                 dee         d	ee         d
edee         deeef         fd            Z xZS )Owlv2VisionModelr   r   c                     t                                          |           t          |          | _        |                                  d S rR   )r   r   rq  vision_modelrg  r   s     r'   r   zOwlv2VisionModel.__init__t  sA       26::r)   r   c                 $    | j         j        j        S rR   )rz  r   r   rH   s    r'   rj  z%Owlv2VisionModel.get_input_embeddingsz  s     +;;r)   NFr   rJ  r   rK  c                 6    |                      |||||          S )a  
        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Owlv2VisionModel

        >>> model = Owlv2VisionModel.from_pretrained("google/owlv2-base-patch16")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```r   r   rJ  r   rK  )rz  )rB   r   r   rJ  r   rK  s         r'   r   zOwlv2VisionModel.forward}  s0    6   %/!5%=# ! 
 
 	
r)   NNNFN)rJ   rK   rL   r   rO   main_input_namer   r	   r@  rj  r   r   r$   rN   r   r   rF   r   r   r   r   s   @r'   rx  rx  p  s        $O0      <bi < < < <  59,0/3).&* 
  
u01 
 $D> 
 'tn	 

 #' 
 d^ 
 
u00	1 
  
  
 ^ 
  
  
  
  
r)   rx  c                       e Zd ZU eed<   def fdZ e            e	 ddej	        de
ej	                 dej        fd                        Z e            e	 dd	ej	        d
edej        fd                        Ze	 	 	 	 	 	 	 	 	 dde
ej                 d	e
ej                 de
ej	                 de
e         de
e         de
e         d
ede
e         de
e         deeef         fd            Z xZS )r/  r   c                    t                                          |           t          |j        t                    s%t          dt          |j                   d          t          |j        t                    s%t          dt          |j                   d          |j        }|j        }|j	        | _	        |j
        | _        |j
        | _        t          |          | _        t          |          | _        t#          j        | j        | j	        d          | _        t#          j        | j        | j	        d          | _        t#          j        t-          j        |j                            | _        |                                  d S )NzLconfig.text_config is expected to be of type Owlv2TextConfig but is of type .zPconfig.vision_config is expected to be of type Owlv2VisionConfig but is of type F)r   )r   r   r)  text_configr   	TypeErrortypevision_configr   projection_dimr   r1  r3  rX  rf  rq  rz  r	   r   r2  r0  r   r$   r   r6  r4  rg  )rB   r   r  r  r   s       r'   r   zOwlv2Model.__init__  sy      &,o>> 	0+,,0 0 0  
 &.0ABB 	2-..2 2 2  
 (,$3)5 - 9.{;;2=AA!#4+@$BU\a!b!b!b!y)<d>QX]^^^<V5R(S(STT 	r)   Nr   r   r   c                 h    |                      ||          }|                     |j                  }|S )a  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids)

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`Owlv2TextModel`].

        Examples:
        ```python
        >>> import torch
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> inputs = processor(
        ...     text=[["a photo of a cat", "a photo of a dog"], ["photo of a astranaut"]], return_tensors="pt"
        ... )
        >>> with torch.inference_mode():
        ...     text_features = model.get_text_features(**inputs)
        ```)r   r   )rf  r0  r^  )rB   r   r   text_outputstext_featuress        r'   get_text_featureszOwlv2Model.get_text_features  s6    > 48??Ygu?3v3v,,\-GHHr)   Fr   r   c                 h    |                      ||          }|                     |j                  }|S )av  
        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`Owlv2VisionModel`].

        Examples:
        ```python
        >>> import torch
        >>> from transformers.image_utils import load_image
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = load_image(url)

        >>> inputs = processor(images=image, return_tensors="pt")
        >>> with torch.inference_mode():
        ...     image_features = model.get_image_features(**inputs)
        ```r   r   )rz  r2  r^  )rB   r   r   vision_outputsimage_featuress        r'   get_image_featureszOwlv2Model.get_image_features  sC    8 6:5F5F%%= 6G 6
 6
 //0LMMr)   return_lossr   rJ  return_base_image_embedsrK  c
           	      2   ||n| j         j        }||n| j         j        }|	|	n| j         j        }	|                     |||||	          }
|                     |||||	          }|d         }|                     |          }|
d         }|                     |          }|t          j	        
                    |ddd          z  }|t          j	        
                    |ddd          z  }| j                                                            |j                  }t          j        ||                                          |z  }|                                }d}|rt#          |          }|}|	s||||||
f}||f|z   n|S t%          |||||||
	          S )
a4  
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        return_base_image_embeds (`bool`, *optional*):
            Whether or not to return the base image embeddings.

        Examples:
        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Owlv2Model

        >>> model = Owlv2Model.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> inputs = processor(text=[["a photo of a cat", "a photo of a dog"]], images=image, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```Nr}  ro  r   r^   r   T)ordr   keepdim)r2   r3   r4   r5   r6   r7   r8   )r   r   rJ  rQ  rz  rf  r0  r2  r$   linalgnormr4  expr   r!   matmulr,   r/   r1   )rB   r   r   r   r  r   rJ  r   r  rK  r  r  r5   r6   text_embeds_normr4  r4   r3   r2   outputs                       r'   r   zOwlv2Model.forward  s   F 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]**%/!5%=# + 
 
 )/!5# ' 
 
 #1o**;77%a(--l;; $el&7&7!QS]a&7&b&bb&):):;ASU_c):)d)dd &**,,//0CDD,'79I9IJJ[X*,,.. 	/o..D& 	F&lT`bpqF)-)9TGf$$vE-+#%* .
 
 
 	
r)   rR   r   )	NNNNNNFNN)rJ   rK   rL   r   rO   r   r   r   r$   r   r   rN   r  r   r  r   r   rF   r1   r   r   r   s   @r'   r/  r/    s         {      @ %$&& 26   <  !.  
		      ^ '& D %$&& */   l  #'  
		      ^ '& D  154815&*,0/3).37&*Z
 Z
E,-Z
 u01Z
 !.	Z

 d^Z
 $D>Z
 'tnZ
 #'Z
 #+4.Z
 d^Z
 
uk!	"Z
 Z
 Z
 ^Z
 Z
 Z
 Z
 Z
r)   r/  c                   N     e Zd Zddedef fdZdej        dej        fdZ	 xZ
S )	Owlv2BoxPredictionHead   r   out_dimc                 ,   t                                                       |j        j        }t	          j        ||          | _        t	          j        ||          | _        t	          j                    | _	        t	          j        ||          | _
        d S rR   )r   r   r  r   r	   r   dense0dense1GELUgeludense2)rB   r   r  r   r   s       r'   r   zOwlv2BoxPredictionHead.__init__n  sn    $0iu--iu--GII	iw//r)   r  r   c                     |                      |          }|                     |          }|                     |          }|                     |          }|                     |          }|S rR   )r  r  r  r  )rB   r  r  s      r'   r   zOwlv2BoxPredictionHead.forwardw  s\    ^,,6""V$$6""V$$r)   )r  )rJ   rK   rL   r   rZ   r   r$   r   rN   r   r   r   s   @r'   r  r  m  sw        0 0{ 0S 0 0 0 0 0 0el u7H        r)   r  c            	            e Zd Zdef fdZdej        deej                 deej                 de	ej                 fdZ
 xZS )Owlv2ClassPredictionHeadr   c                 l   t                                                       |j        j        }|j        j        | _        t          j        | j        |          | _        t          j        | j        d          | _	        t          j        | j        d          | _
        t          j                    | _        d S )Nr   )r   r   r  r   r  	query_dimr	   r   r  logit_shiftr4  ELUelu)rB   r   r  r   s      r'   r   z!Owlv2ClassPredictionHead.__init__  s    $0-9i889T^Q779T^Q77688r)   r6   query_embeds
query_maskr   c                     |                      |          }|L|j        }|j        d d         \  }}t          j        ||| j        f                              |          }||fS |t          j                            |dd          dz   z  }|t          j                            |dd          dz   z  }t          j	        d||          }| 
                    |          }	|                     |          }
|                     |
          dz   }
||	z   |
z  }|v|j        dk    rt          j        |d	          }t          j        |d
k    t          j        |j                  j        |          }|                    t          j                  }||fS )Nr^   r   T)r   r  gư>z...pd,...qd->...pqr   r   r   r   )r  r!   r   r$   zerosr  r   r  r  einsumr  r4  r  ndimr   wherefinforT   rb   rU   )rB   r6   r  r  image_class_embedsr!   r   r   pred_logitsr  r4  s              r'   r   z Owlv2ClassPredictionHead.forward  s    "[[66'.F&8&>rr&B#J+z;&OPPSSTZ[[K!344 05<3D3DEW]_im3D3n3nqu3uv#u|'8'82W['8'\'\_c'cd l#79K\ZZ &&|44&&|44hh{++a/"[0K?!"""_ZR@@@
+jAou{;CT7U7U7Y[fggK%..77K/00r)   )rJ   rK   rL   r   r   r$   rN   r   r   rF   r   r   r   s   @r'   r  r    s        	{ 	 	 	 	 	 	!1'!1 u01!1 U\*	!1
 
u 	!!1 !1 !1 !1 !1 !1 !1 !1r)   r  c                       e Zd ZU eed<   def fdZedededej	        fd            Z
dej        dej        fdZ ed	
          	 d#dededeej                 dej	        fd            Z	 d$dej        dej        dedej        fdZ	 	 d%dej        deej                 deej	                 deej                 fdZ	 	 	 d&dej	        dej        dej	        dee         dee         dedeej                 fdZ	 	 	 d&dej        dee         dee         dedeej                 f
dZ	 d$dej        dej        dedej        fdZe	 	 	 	 	 d'dej        deej                 dee         dee         ded ee         defd!            Ze	 	 	 	 	 d'dej	        dej        deej	                 dee         dee         ded ee         defd"            Z xZS )(Owlv2ForObjectDetectionr   c                    t                                          |           t          |          | _        t	          |          | _        t          |          | _        t          |d          | _        t          j
        |j        j        |j        j                  | _        t          j                    | _        || _        | j        j        j        | j        j        j        z  | _        | j        j        j        | j        j        j        z  | _        |                     | j        | j                  | _        |                                  d S )Nr   )r  r  )r   r   r/  r!  r  
class_headr  box_headobjectness_headr	   r  r  r   r  
layer_normSigmoidsigmoidr   r   r   num_patches_heightnum_patches_widthcompute_box_biasbox_biasrg  r   s     r'   r   z Owlv2ForObjectDetection.__init__  s      ''
26::.v665faHHH,v';'GVMaMpqqqz||"&+";"F$+JcJn"n!%!:!EIbIm!m--d.EtG]^^ 	r)   r  r  r   c                 f   t          j        d|dz   t           j                  }t          j        d| dz   t           j                  }t          j        ||d          \  }}t          j        ||fd          }|dxx         |z  cc<   |dxx         | z  cc<   |                    dd	          }|S )
Nr   )rT   xy)indexingr   r   .r   .r   r^   )r$   r%   rU   meshgridstackr   )r  r  x_coordinatesy_coordinatesxxyybox_coordinatess          r'   !normalize_grid_corner_coordinatesz9Owlv2ForObjectDetection.normalize_grid_corner_coordinates  s     Q(9A(=U]SSSQ(:Q(>emTTT}tLLLB  +r2hB777#44#55 *..r155r)   r  c                 h    |                                 }|                     |          }|d         }|S )a#  Predicts the probability that each image feature token is an object.

        Args:
            image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)):
                Features extracted from the image.
        Returns:
            Objectness scores.
        r  )detachr  )rB   r  r{   s      r'   objectness_predictorz,Owlv2ForObjectDetection.objectness_predictor  s:     (..00 00@@-f5  r)   r^   )maxsizeNfeature_mapc                    |t          d          |                     ||          }t          j        |dd          }t          j        |dz             t          j        | dz             z
  }t          j        |d          }|dxx         |z  cc<   |dxx         |z  cc<   t          j        |dz             t          j        | dz             z
  }t          j        ||gd          }|S )	NzOfeature_map has been deprecated as an input. Please pass in num_patches insteadr$  r'  g-C6?r  r  r   r   )rr   r  r$   cliploglog1p	full_liker   )	rB   r  r  r  r  box_coord_biasbox_sizebox_size_biasr  s	            r'   r  z(Owlv2ForObjectDetection.compute_box_bias  s   
 "nooo@@ASUfgg*_c3?? ?T#9::U[/IY\`I`=a=aa ?>377--..	(T/22U[(TAQ5R5RR 9nm<"EEEr)   Fimage_featsr   c                     |                      |          }|r#|j        \  }}}}|                     ||          }n| j        }|                    |j                  }||z  }|                     |          }|S )a  
        Args:
            image_feats:
                Features extracted from the image, returned by the `image_text_embedder` method.
            feature_map:
                A spatial re-arrangement of image_features, also returned by the `image_text_embedder` method.
            interpolate_pos_encoding:
                Whether to interpolate the pre-trained position encodings.
        Returns:
            pred_boxes:
                List of predicted boxes (cxcywh normalized to 0, 1) nested within a dictionary.
        )r  r   r  r  r   r!   r  )	rB   r  r  r   r|   r   r  r  r  s	            r'   box_predictorz%Owlv2ForObjectDetection.box_predictor  s    & ]];//
 $ 	%:E:K7A!#4a,,-?ARSSHH}H;;{122h
\\*--
r)   r  r  c                 >    |                      |||          \  }}||fS )a8  
        Args:
            image_feats:
                Features extracted from the `image_text_embedder`.
            query_embeds:
                Text query embeddings.
            query_mask:
                Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
        )r  )rB   r  r  r  r  r  s         r'   class_predictorz'Owlv2ForObjectDetection.class_predictor!  s,     -1OOKWa,b,b)(/00r)   r   r   r   r   rJ  c           	      T   |                      ||||||d          }|r5|j        \  }}}	}
|	| j        j        j        z  }|
| j        j        j        z  }n| j        }| j        }|j        d         }| j         j        	                    |          }t          j        |d d d dd d f         |d d d df         j                  }|d d dd d d f         |z  }|                     |          }|j        d         |||j        d         f}|                    |          }|d         }|||fS )NT)r   r   r   r   rJ  r   rK  r   r   r   )r!  r   r   r  r   r  r  r8   rz  rt  r$   broadcast_tor  r   )rB   r   r   r   r   rJ  r   r  r   r   r   r  r  rO  r6   class_token_outnew_sizer5   s                     r'   image_text_embedderz+Owlv2ForObjectDetection.image_text_embedder5  sz    **%)/!5%=  
 
 $ 	7"."4Aq&%!'4;+D+O!O %)B)M M!%!8 $ 6 $7:z.==>OPP  ,\!!!RaR(-C\RSRSRSUXVXUXRXEYE_`` $AAAqrr111H-?|44 q!r"	
 $++H55bk\733r)   c                 :   | j                             ||d          }|r5|j        \  }}}}|| j        j        j        z  }	|| j        j        j        z  }
n| j        }	| j        }
|d         }| j         j                            |          }t          j
        |d d d dd d f         |d d d df         j                  }|d d dd d d f         |z  }|                     |          }|j        d         |	|
|j        d         f}|                    |          }||fS )NT)r   r   rK  r   r   r   )r!  rz  r   r   r  r   r  r  rt  r$   r  r  r   )rB   r   r   rJ  r   r  r   r   r   r  r  rO  r6   r  r  s                  r'   image_embedderz&Owlv2ForObjectDetection.image_embedderi  sd    00%@Xfj 1 
 
 $ 	7"."4Aq&%!'4;+D+O!O %)B)M M!%!8 $ 6 +1-z.==>OPP  ,\!!!RaR(-C\RSRSRSUXVXUXRXEYE_`` $AAAqrr111H-?|44 q!r"	
 $++H55n--r)   query_image_featuresquery_feature_mapc                    |                      |          \  }}|                     |||          }t          |          }g }g }	|j        }
t	          |j        d                   D ]Q}t          j        g dg|
          }||         }t          ||          \  }}t          j	        |d         dk              rt          ||          }t          j        |          dz  }|d         |k                                    }|                                r||         |                    d                   }t          j        ||         d          }t          j        d||          }|t          j        |                   }|                    ||         |                    |	                    |           S|r)t          j        |          }t          j        |	          }nd	\  }}|||fS )
Nr   )r   r   r   r   r    r$  g?r   )axiszd,id->iNN)r  r  r   r!   rG  r   r$   r   ro   rq   rv   rc   nonzeronumelsqueezer%  r  argminappendr  )rB   r  r  r   r   r}   r|   pred_boxes_as_cornersbest_class_embedsbest_box_indicespred_boxes_deviceieach_query_boxeach_query_pred_boxesiousiou_thresholdselected_indsselected_embeddingsmean_embedsmean_simbest_box_indr  box_indicess                          r'   embed_image_queryz)Owlv2ForObjectDetection.embed_image_query  s    ../CDD<''(<>OQijj
 8 D D 18+1!455 	6 	6A"\<<<.ARSSSN$9!$<!n.CDDGD! yaC(( R*>;PQQ "IdOOc1M!!W5>>@@M""$$ 6&21om6K6KA6N6N&O##jaqAAA <	;@STT,U\(-C-CD!((a)FGGG ''555 	3 ;'899L+&677KK(2%L+[*44r)   query_pixel_valuesrK  c           
         ||n| j         j        }||n| j         j        }||n| j         j        }|                     ||          d         }|                     ||||          \  }}	|j        \  }
}}}t          j        ||
||z  |f          }|j        \  }
}}}t          j        ||
||z  |f          }|                     |||          \  }}}| 	                    ||          \  }}| 
                    |||          }|s6|||||||	                                f}t          d |D                       }|S t          ||||||d|	          S )a  
        query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values of query image(s) to be detected. Pass in one query image per target image.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch
        >>> from transformers import AutoProcessor, Owlv2ForObjectDetection

        >>> processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> query_url = "http://images.cocodataset.org/val2017/000000001675.jpg"
        >>> query_image = Image.open(requests.get(query_url, stream=True).raw)
        >>> inputs = processor(images=image, query_images=query_image, return_tensors="pt")

        >>> # forward pass
        >>> with torch.no_grad():
        ...     outputs = model.image_guided_detection(**inputs)

        >>> target_sizes = torch.Tensor([image.size[::-1]])

        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_image_guided_detection(
        ...     outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
        ... )
        >>> i = 0  # Retrieve predictions for the first image
        >>> boxes, scores = results[i]["boxes"], results[i]["scores"]
        >>> for box, score in zip(boxes, scores):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected similar object with confidence {round(score.item(), 3)} at location {box}")
        Detected similar object with confidence 0.938 at location [327.31, 54.94, 547.39, 268.06]
        Detected similar object with confidence 0.959 at location [5.78, 360.65, 619.12, 366.39]
        Detected similar object with confidence 0.902 at location [2.85, 360.01, 627.63, 380.8]
        Detected similar object with confidence 0.985 at location [176.98, -29.45, 672.69, 182.83]
        Detected similar object with confidence 1.0 at location [6.53, 14.35, 624.87, 470.82]
        Detected similar object with confidence 0.998 at location [579.98, 29.14, 615.49, 489.05]
        Detected similar object with confidence 0.985 at location [206.15, 10.53, 247.74, 466.01]
        Detected similar object with confidence 0.947 at location [18.62, 429.72, 646.5, 457.72]
        Detected similar object with confidence 0.996 at location [523.88, 20.69, 586.84, 483.18]
        Detected similar object with confidence 0.998 at location [3.39, 360.59, 617.29, 499.21]
        Detected similar object with confidence 0.969 at location [4.47, 449.05, 614.5, 474.76]
        Detected similar object with confidence 0.966 at location [31.44, 463.65, 654.66, 471.07]
        Detected similar object with confidence 0.924 at location [30.93, 468.07, 635.35, 475.39]
        ```Nr  r   )r   r   rJ  r   )r  r  c              3      K   | ]}||V  	d S rR   rP   r@   xs     r'   rC   zAOwlv2ForObjectDetection.image_guided_detection.<locals>.<genexpr>(  "      >>1>>r)   )r6   r   r   r   r   r}   r7   r8   )r   r   rJ  rK  r  r   r$   r   r  r  r  r>   rF   r   )rB   r   r  r   rJ  r   rK  r  r  r  r   r  r  
hidden_dimr  query_image_featsr  r  r   r  r}   r   r  s                          r'   image_guided_detectionz.Owlv2ForObjectDetection.image_guided_detection  s   v 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+BY !//+F^ 0 
 

 '+&9&9%/!5%=	 ': '
 '
#^ ITHYE
&(9:mK*>PSd>dfp1qrrHYH_E
&(9:!M
,>AR,RT^_
 
 <@;Q;Q02J<
 <
8&(8
 '+&:&:{am&:&n&n#l !..{KIabb 	!! ''))F >>f>>>>>FM4$0/-%" .	
 	
 	
 		
r)   c           
         ||n| j         j        }||n| j         j        }||n| j         j        }|                     ||||||          \  }}	}
|
j        }|
j        }|	j        \  }}}}t          j	        |	|||z  |f          }|j        d         |z  }|	                    |||j        d                   }|	                    |||j        d                   }|d         dk    }| 
                    |||          \  }}|                     |          }|                     ||	|          }|sI|||||	||                                |                                f}t          d |D                       }|S t          |	|||||||          S )a	  
        input_ids (`torch.LongTensor` of shape `(batch_size * num_max_text_queries, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. [What are input
            IDs?](../glossary#input-ids).
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the last hidden state. See `text_model_last_hidden_state` and
            `vision_model_last_hidden_state` under returned tensors for more detail.

        Examples:
        ```python
        >>> import requests
        >>> from PIL import Image
        >>> import torch

        >>> from transformers import Owlv2Processor, Owlv2ForObjectDetection

        >>> processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
        >>> model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)
        >>> text_labels = [["a photo of a cat", "a photo of a dog"]]
        >>> inputs = processor(text=text_labels, images=image, return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> # Target image sizes (height, width) to rescale box predictions [batch_size, 2]
        >>> target_sizes = torch.tensor([(image.height, image.width)])
        >>> # Convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
        >>> results = processor.post_process_grounded_object_detection(
        ...     outputs=outputs, target_sizes=target_sizes, threshold=0.1, text_labels=text_labels
        ... )
        >>> # Retrieve predictions for the first image for the corresponding text queries
        >>> result = results[0]
        >>> boxes, scores, text_labels = result["boxes"], result["scores"], result["text_labels"]
        >>> for box, score, text_label in zip(boxes, scores, text_labels):
        ...     box = [round(i, 2) for i in box.tolist()]
        ...     print(f"Detected {text_label} with confidence {round(score.item(), 3)} at location {box}")
        Detected a photo of a cat with confidence 0.614 at location [341.67, 23.39, 642.32, 371.35]
        Detected a photo of a cat with confidence 0.665 at location [6.75, 51.96, 326.62, 473.13]
        ```N)r   r   r   r   rJ  r   r   r   r  c              3      K   | ]}||V  	d S rR   rP   r	  s     r'   rC   z2Owlv2ForObjectDetection.forward.<locals>.<genexpr>  r  r)   )r6   r5   r|   r   r{   r}   r7   r8   )r   r   rJ  rK  r  r7   r8   r   r$   r   r  r  r  r>   rF   ry   )rB   r   r   r   r   rJ  r   rK  r  r  r  r  r  r   r  r  r  r  max_text_queriesr  r  r}   r{   r|   r  s                            r'   r   zOwlv2ForObjectDetection.forward6  s   h 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+BY .2-E-E%)/!5%= .F .
 .
*k7 0 4HSHYE
&(9:mK*>PSd>dfp1qrr %?1-;#++J8H,J\]_J`aa %%j2BIOTVDWXX	v&*
 '+&:&:;V`&a&a#l !55kBB ''[BZ[[
 	!%%''''))	F >>f>>>>>FM)$$!/%* .	
 	
 	
 		
r)   rR   r   r  r  r~  )rJ   rK   rL   r   rO   r   staticmethodrZ   r$   r   r  rN   r  r   r   r  r   r  rF   r  r  r  r  r   r   r  ry   r   r   r   s   @r'   r  r    sI        {      $ c VY ^c^j    \ !53D !IZ ! ! ! ! Yq ko "%:=LTUZUfLg	   6 */	 & & #'	
 
	   J 59-1	1 1&1 u011 U\*	1
 
u 	!1 1 1 12 -1/3).14 14<14 '14 	14
 $D>14 'tn14 #'14 
u 	!14 14 14 14n -1/3).(. (.'(. $D>(. 'tn	(.
 #'(. 
u 	!(. (. (. (.^ */	*5 *5#/*5 !,*5 #'	*5
 
	*5 *5 *5 *5X  ;?,0/3).&*s
 s
's
 %U%67s
 $D>	s

 'tns
 #'s
 d^s
 
/s
 s
 s
 ^s
j 
 26,0/3).&*r
 r
<r
 'r
 !.	r

 $D>r
 'tnr
 #'r
 d^r
 
$r
 r
 r
 ^r
 r
 r
 r
 r
r)   r  )r/  r   rd  rx  r  )BrM   dataclassesr   	functoolsr   typingr   r   r   r$   r   r	   activationsr   modeling_attn_mask_utilsr   r   modeling_layersr   modeling_outputsr   r   modeling_utilsr   utilsr   r   r   r   r   r   configuration_owlv2r   r   r   transformers.image_transformsr   
get_loggerrJ   loggerr(   r/   r1   r[   r_   ro   rv   ry   r   r@  r   r   r   r  r  r   rB  rX  rd  rq  rx  r/  r  r  r  __all__rP   r)   r'   <module>r!     s*     ! ! ! ! ! !       ' ' ' ' ' ' ' ' ' '          ! ! ! ! ! ! d d d d d d d d 9 9 9 9 9 9 K K K K K K K K - - - - - -                Q P P P P P P P P P  GFFFFFF 
	H	%	%`U\ `el ` ` ` `
-5< -EL - - - - !
 !
 !
 !
 !
+ !
 !
  !
JGv G& G G G GEF Ev E E E E"  "' ' '0   
/
 /
 /
 /
 /
 /
 /
  /
d   *
 *
 *
 *
 *
K *
 *
  *
\J J J J JBI J J J\    ")   @h2 h2 h2 h2 h2RY h2 h2 h2X    ry    / / / / /2 / / /d ,) ,) ,) ,) ,)? ,) ,) ,)`M
 M
 M
 M
 M
29 M
 M
 M
bI
 I
 I
 I
 I
29 I
 I
 I
Z3
 3
 3
 3
 3
) 3
 3
 3
n4
 4
 4
 4
 4
RY 4
 4
 4
p.
 .
 .
 .
 .
+ .
 .
 .
b F
 F
 F
 F
 F
% F
 F
 F
T    RY   (-1 -1 -1 -1 -1ry -1 -1 -1`x
 x
 x
 x
 x
2 x
 x
 x
v r
q
qr)   