
     `ij              	          d Z ddlZddlmZ ddlmZmZ ddlZddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZ ddlmZmZmZmZ ddlmZ  ej        e          Ze ed           G d de                                  Ze ed           G d de                                  Z G d de	j                  Z  G d de	j                  Z! G d de	j                  Z" G d de	j                  Z#d@dej$        d e%d!e&d"ej$        fd#Z' G d$ d%e	j                  Z( G d& d'e          Z) G d( d)e	j                  Z* G d* d+e	j+                  Z, G d, d-e	j                  Z- G d. d/e	j                  Z.e G d0 d1e                      Z/e G d2 d3e/                      Z0d4ej$        d5e1d"ej$        fd6Z2d4ej$        d7e1d8e1d"ej$        fd9Z3 G d: d;e	j                  Z4 ed<           G d= d>e/                      Z5g d?Z6dS )AzPyTorch SegGpt model.    N)	dataclass)OptionalUnion)nn)
functional   )ACT2FN)GradientCheckpointingLayer)PreTrainedModel)ModelOutputauto_docstringlogging	torch_int   )SegGptConfigz1
    Output type of [`SegGptEncoderOutput`].
    )custom_introc                       e Zd ZU dZej        ed<   dZee	ej                          ed<   dZ
ee	ej                          ed<   dZee	ej                          ed<   dS )SegGptEncoderOutputay  
    last_hidden_state (`torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`):
        Sequence of hidden-states at the output of the last layer of the model.
    hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
        Tuple of *torch.FloatTensor* (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    intermediate_hidden_states (`tuple[torch.FloatTensor]`, *optional*, returned when `config.intermediate_hidden_state_indices` is set):
        Tuple of `torch.FloatTensor` of shape `(batch_size, patch_height, patch_width, hidden_size)`.
        Each element in the Tuple corresponds to the output of the layer specified in `config.intermediate_hidden_state_indices`.
        Additionally, each feature passes through a LayerNorm.
    last_hidden_stateNhidden_states
attentionsintermediate_hidden_states)__name__
__module____qualname____doc__torchFloatTensor__annotations__r   r   tupler   r        ~/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/seggpt/modeling_seggpt.pyr   r   #   s           ((((8<M8E%"345<<<59Ju012999EIu/@)A BIIIIIr"   r   z;
    Output type of [`SegGptImageSegmentationOutput`].
    c                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eeej                          ed<   dZeeej                          ed<   dS )SegGptImageSegmentationOutputa  
    loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
        The loss value.
    pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
        The predicted masks.
    hidden_states (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_hidden_states=True`):
        Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
        of shape `(batch_size, patch_height, patch_width, hidden_size)`.
    attentions (`tuple[torch.FloatTensor]`, `optional`, returned when `config.output_attentions=True`):
        Tuple of `torch.FloatTensor` (one for each layer) of shape
        `(batch_size, num_heads, seq_len, seq_len)`.
    Nloss
pred_masksr   r   )r   r   r   r   r&   r   r   r   r   r'   r   r    r   r!   r"   r#   r%   r%   ?   s           )-D(5$
%,,,.2J*+2228<M8E%"345<<<59Ju01299999r"   r%   c                   (     e Zd ZdZ fdZd Z xZS )SegGptPatchEmbeddingsz
    This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
    `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
    Transformer.
    c                    t                                                       |j        |j        }}|j        |j        }}t          |t          j        j	                  r|n||f}t          |t          j        j	                  r|n||f}|d         |d         z  |d         |d         z  z  }|| _        || _        || _        || _
        t          j        ||||          | _        d S )Nr   r   )kernel_sizestride)super__init__
image_size
patch_sizenum_channelshidden_size
isinstancecollectionsabcIterablenum_patchesr   Conv2d
projection)selfconfigr/   r0   r1   r2   r7   	__class__s          r#   r.   zSegGptPatchEmbeddings.__init__a   s    !'!2F4EJ
$*$79Kk#-j+/:R#S#SqZZZdfpYq
#-j+/:R#S#SqZZZdfpYq
!!}
15*Q-:VW=:XY$$(&)L+:^hiiir"   c                 P   |j         \  }}}}|| j        k    rt          d          || j        d         k    s|| j        d         k    r2t          d| d| d| j        d          d| j        d          d	          |                     |                              ddd	d          }|S )
NzeMake sure that the channel dimension of the pixel values match with the one set in the configuration.r   r   zInput image size (*z) doesn't match model ().   r   )shaper1   
ValueErrorr/   r9   permute)r:   pixel_values
batch_sizer1   heightwidth
embeddingss          r#   forwardzSegGptPatchEmbeddings.forwardo   s    2>2D/
L&%4,,,w   T_Q'''5DOA4F+F+FwVwwewwDO\]L^wwaeapqraswww   __\22::1aAFF
r"   )r   r   r   r   r.   rI   __classcell__r<   s   @r#   r)   r)   Z   sV         j j j j j      r"   r)   c                        e Zd ZdZdeddf fdZdededej        fdZ		 	 dd	ej        d
ej        de
ej                 de
e         dej        f
dZ xZS )SegGptEmbeddingszX
    Construct the embeddings from patch, position embeddings for input and prompt.
    r;   returnNc                 8   t                                                       t          j        t	          j        ddd|j                            | _        t          j        t	          j        ddd|j                            | _        t          j        t	          j        ddd|j                            | _	        t          j        t	          j        ddd|j                            | _
        t          j        t	          j        ddd|j                            | _        t          |          | _        |j        |j        z  dz  dz   }t          j        t	          j        d||j                            | _        t          j        |j                  | _        d S )Nr   r@   )r-   r.   r   	Parameterr   zerosr2   
mask_tokensegment_token_inputsegment_token_prompttype_token_semantictype_token_instancer)   patch_embeddingspretrain_image_sizer0   randnposition_embeddingsDropouthidden_dropout_probdropout)r:   r;   num_positionsr<   s      r#   r.   zSegGptEmbeddings.__init__   s7   ,u{1aF<N'O'OPP#%<Aq!VEW0X0X#Y#Y $&LQ1fFX1Y1Y$Z$Z!#%<Aq!VEW0X0X#Y#Y #%<Aq!VEW0X0X#Y#Y  5f = =3v7HHQNQRR#%<A}fN`0a0a#b#b z&"<==r"   rF   rG   c                    | j         d d dd f         }|j        d         }t          |dz            }t          j                                        s||k    s||k    r^t          j        |                    d||d          	                    dddd          ||fdd	          }|	                    dddd          S |                    d||d          S )
Nr         ?r   r   r@   bicubicF)sizemodealign_corners)
rZ   rA   r   r   jit
is_tracingFinterpolatereshaperC   )r:   rF   rG   patch_pos_embedr7   pretrain_patch_sizes         r#   interpolate_pos_encodingz)SegGptEmbeddings.interpolate_pos_encoding   s    2111abb59%+A.'S(899 9!! 
	A%8F%B%BFY]bFbFbm''+>@SUWXX``abdeghjklle_#	  O #**1aA666"**1feR@@@r"   rD   prompt_pixel_valuesbool_masked_posembedding_typec                 X   |                      |          }|                      |          }|j        \  }}}	}
| j                            |||	d          }|                    d                              |                              d||	d          }|d|z
  z  ||z  z   }||nd}|                     ||	          }|| j        z   }|| j	        z   }||z   }||z   }|dk    r| j
        }n |dk    r| j        }nt          d|           ||z   }||z   }t          j        ||fd          }|S )Nra   r   instancesemanticzBEmbedding type should be either 'semantic' or 'instance', but got r   dim)rW   rA   rR   expand	unsqueezetype_asrj   rm   rS   rT   rU   rV   rB   r   cat)r:   rD   rn   ro   rp   input_embeddingsprompt_embeddingsrE   patch_heightpatch_width_rR   w	pos_embedtype_embeddingrH   s                   r#   rI   zSegGptEmbeddings.forward   s     00>> 112EFF3C3I0
L+q_++JkSUVV
%%b))11*==EEb,Xcefgg-Q7*q.H+9+E: 11,LL	 ,d.FF-0II ,i7-	9 Z''!5NNz))!5NNrbprrsss+n<->Y 02CD!LLL
r"   )NN)r   r   r   r   r   r.   intr   Tensorrm   r   
BoolTensorstrrI   rJ   rK   s   @r#   rM   rM   }   s         >| > > > > > > > As A3 A5< A A A A, 7;(,+ +l+ #\+ "%"23	+
 !+ 
+ + + + + + + +r"   rM   c                        e Zd ZdZ fdZdededej        dej        fdZdej        d	ej        d
ej        dej        de	eef         de	eef         dej        fdZ
ddej        dej        fdZ xZS )SegGptAttentionz=Multi-head Attention block with relative position embeddings.c                 R   t                                                       |j        |j        }}t	          |t
          j        j                  r|n||f}t	          |t
          j        j                  r|n||f}|d         |j        z  |d         |j        z  f}|j        |j	        z  }|j	        | _	        |dz  | _
        t          j        |j        |j        dz  |j                  | _        t          j        |j        |j                  | _        |j        | _        | j        r|t#          d          t          j        t'          j        d|d         z  dz
  |                    | _        t          j        t'          j        d|d         z  dz
  |                    | _        d S d S )Nr   r   g      r   biaszBInput size must be provided if using relative positional encoding.r@   )r-   r.   r/   r0   r3   r4   r5   r6   r2   num_attention_headsscaler   Linearqkv_biasqkvproj use_relative_position_embeddingsrB   rP   r   rQ   	rel_pos_h	rel_pos_w)r:   r;   r/   r0   
input_sizehead_dimr<   s         r#   r.   zSegGptAttention.__init__   s   !'!2F4EJ
#-j+/:R#S#SqZZZdfpYq
#-j+/:R#S#SqZZZdfpYq
 mv'88*Q-6K\:\]
%)CC#)#= t^
9V/1Ca1Gfo^^^If0&2DEE	060W-0 	X! !efff  \%+a*Q-6G!6KX*V*VWWDN\%+a*Q-6G!6KX*V*VWWDNNN	X 	Xr"   q_sizek_sizerel_posrN   c                 n   t          dt          ||          z  dz
            }t          j        |                    d|j        d         d                              ddd          |d          }|                    d|                              dd          }t          j        |          dddf         t          ||z  d          z  }t          j        |          dddf         t          ||z  d          z  }||z
  |dz
  t          ||z  d          z  z   }||	                                         S )	a  
        Get relative positional embeddings according to the relative positions of
            query and key sizes.

        Args:
            q_size (int):
                size of the query.
            k_size (int):
                size of key k.
            rel_pos (`torch.Tensor`):
                relative position embeddings (L, channel).

        Returns:
            Extracted positional embeddings according to relative positions.
        r@   r   r   ra   linear)rc   rd   N      ?)
r   maxrh   ri   rj   rA   rC   r   arangelong)	r:   r   r   r   max_rel_distrel_pos_resizedq_coordsk_coordsrelative_coordss	            r#   get_rel_poszSegGptAttention.get_rel_pos   s2     1s66222Q677-OOAw}Q/44<<Q1EE
 
 

 *11"lCCKKAqQQ <''403v3L3LL<''aaa03v3L3LL#h.6A:Vf_VYAZAZ2ZZ335566r"   attnqueryr   r   c                    |\  }}|\  }	}
|                      ||	|          }|                      ||
|          }|j        \  }}}|                    ||||          }t          j        d||          }t          j        d||          }|                    ||||	|
          }||dddddddddf         z   |dddddddddf         z   }|                    |||z  |	|
z            }|S )a  
        Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
        https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

        Args:
            attn (`torch.Tensor`):
                attention map.
            query (`torch.Tensor`):
                query q in the attention layer with shape (batch_size, query_height * query_width, channel).
            rel_pos_h (`torch.Tensor`):
                relative position embeddings (Lh, channel) for height axis.
            rel_pos_w (`torch.Tensor`):
                relative position embeddings (Lw, channel) for width axis.
            q_size (tuple):
                spatial sequence size of query q with (query_height, query_width).
            k_size (tuple):
                spatial sequence size of key k with (key_height, key_width).

        Returns:
            attn (`torch.Tensor`):
                attention map with added relative positional embeddings.
        zbhwc,hkc->bhwkzbhwc,wkc->bhwkN)r   rA   rj   r   einsum)r:   r   r   r   r   r   r   query_heightquery_width
key_height	key_widthrelative_position_heightrelative_position_widthrE   r~   ru   reshaped_queryrel_hrel_ws                      r#   add_decomposed_rel_posz&SegGptAttention.add_decomposed_rel_pos  s$   > %+!k &
I#'#3#3L*i#X#X "&"2"2;	9"U"U"[
Asz<cRR-~?WXX-~?VWW||Jk:yYYeAAAqqq!!!QQQ,--aaaAAAtQQQ6F0GG||J{(BJQZDZ[[r"   Fr   c           	         |j         \  }}}}|                     |                              |||z  d| j        d                              ddddd          }|                    d|| j        z  ||z  d                              d          \  }}	}
|| j        z  |	                    dd          z  }| j        r(| 	                    ||| j
        | j        ||f||f          }t          j        j                            |t          j        d                              |j                  }|rC|                    || j        ||z  d          }|                    || j        z  ||z  d          }nd }||
z                      || j        ||d          }|                    ddddd                              |||d          }|                     |          }||fS )	Nr   ra   r@   r   r      )dtyperu   )rA   r   rj   r   rC   unbindr   	transposer   r   r   r   r   r   r   softmaxfloat32tor   viewr   )r:   r   output_attentionsrE   rF   rG   r~   r   r   keyvalueattn_weightsattn_weights_reshapedattn_outputs                 r#   rI   zSegGptAttention.forward:  s   '4':$
FE1 HH]##WZ%D4LbQQWQ1a## 	  KK:8P+PRX[`R`bdeellmnoosE
*cmmB.C.CC0 	66eT^T^fe_W]_dVe L x*22<u}Z\2]]``afalmm 	)
 %1$5$5j$BZ\bej\jln$o$o!055j4C[6[]cfk]kmoppLL$(!#e+44ZAY[achjlmm!))!Q1a88@@VUZ\^__ii,,233r"   )F)r   r   r   r   r.   r   r   r   r   r    r   rI   rJ   rK   s   @r#   r   r      s       GGX X X X X07# 7s 7U\ 7el 7 7 7 7@+l+ |+ <	+
 <+ c3h+ c3h+ 
+ + + +Z#4 #4U\ #4u| #4 #4 #4 #4 #4 #4 #4 #4r"   r   c                   B     e Zd Z fdZdej        dej        fdZ xZS )	SegGptMlpc                    t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        t          |j	                 | _
        d S N)r-   r.   r   r   r2   mlp_dimlin1lin2r	   
hidden_actactr:   r;   r<   s     r#   r.   zSegGptMlp.__init__b  s\    If0&.AA	Ifnf.@AA	&+,r"   r   rN   c                     |                      |          }|                     |          }|                     |          }|S r   )r   r   r   r:   r   s     r#   rI   zSegGptMlp.forwardh  s;    		-00//		-00r"   )r   r   r   r.   r   r   rI   rJ   rK   s   @r#   r   r   a  s^        - - - - -U\ el        r"   r           Finput	drop_probtrainingrN   c                     |dk    s|s| S d|z
  }| j         d         fd| j        dz
  z  z   }|t          j        || j        | j                  z   }|                                 |                     |          |z  }|S )aF  
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
    however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
    layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
    argument.
    r   r   r   )r   r   device)rA   ndimr   randr   r   floor_div)r   r   r   	keep_probrA   random_tensoroutputs          r#   	drop_pathr   p  s     CxII[^
Q 77E
5EL Y Y YYMYYy!!M1FMr"   c                   j     e Zd ZdZd	dee         ddf fdZdej        dej        fdZ	de
fdZ xZS )
SegGptDropPathzXDrop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).Nr   rN   c                 V    t                                                       || _        d S r   )r-   r.   r   )r:   r   r<   s     r#   r.   zSegGptDropPath.__init__  s$    "r"   r   c                 8    t          || j        | j                  S r   )r   r   r   r   s     r#   rI   zSegGptDropPath.forward  s    FFFr"   c                     d| j          S )Nzp=)r   r:   s    r#   
extra_reprzSegGptDropPath.extra_repr  s    $DN$$$r"   r   )r   r   r   r   r   floatr.   r   r   rI   r   r   rJ   rK   s   @r#   r   r     s        bb# #(5/ #T # # # # # #GU\ Gel G G G G%C % % % % % % % %r"   r   c                        e Zd Zdededdf fdZ	 	 ddej        ded	e	d
e	de
eej        ej        f         eej                 f         f
dZ xZS )SegGptLayerr;   drop_path_raterN   Nc                    t                                                       t          |          | _        t	          |          | _        |dk    rt          |          nt          j                    | _	        t          j
        |j        |j                  | _        t          j
        |j        |j                  | _        d S )Nr   eps)r-   r.   r   	attentionr   mlpr   r   Identityr   	LayerNormr2   layer_norm_epslayernorm_beforelayernorm_after)r:   r;   r   r<   s      r#   r.   zSegGptLayer.__init__  s    (00V$$;IC;O;O777UWU`UbUb "V-?VEZ [ [ [!|F,>FDYZZZr"   Fr   ensemble_condfeature_ensembler   c                    |                      |                     |          |          }|d         }|dd          }|r|j        d         dz  |k    r|                    |j        d         dz  d          \  }}	|dk    ra|j        d         dz  }
|	                    d|
d          }	|	                    dd                              |	          }	 |	j        |j         }	n*|	                    dd                              |	          }	t          j        ||	gd          }| 	                    |          |z   }|}| 
                    |          }|                     |          }|| 	                    |          z   }|f|z   }|S )	N)r   r   r   r@   rt   ra   T)ru   keepdim)r   r   rA   splitrj   mean	expand_asr   ry   r   r   r   )r:   r   r   r   r   self_attention_outputsattention_outputoutputspromptinputsnum_promptsresiduals               r#   rI   zSegGptLayer.forward  s    "&!!-00/ "0 "
 "
 2!4(, 		B 0 6q 9Q >- O O-334D4J14MQR4RXY3ZZNFF!!.4Q71<;;;D99CCFKK'6D99CCFKK$y&&)9qAAA '788=H ,,];;// 4>>-#@#@@ "W,r"   )FF)r   r   r   r   r   r.   r   r   r   boolr   r    rI   rJ   rK   s   @r#   r   r     s        [| [U [t [ [ [ [ [ [ "'"'# #|# # 	#
  # 
uU\5</0%2EE	F# # # # # # # #r"   r   c                   l     e Zd Zdeddf fdZ	 	 	 	 ddej        ded	ed
ededee	e
f         fdZ xZS )SegGptEncoderr;   rN   Nc                 z   t                                                       | _        d t          j        dj        j        d          D             t          j        fdt          j                  D                       | _
        t          j        j        j                  | _        d| _        d S )Nc                 6    g | ]}|                                 S r!   )item).0xs     r#   
<listcomp>z*SegGptEncoder.__init__.<locals>.<listcomp>  s     rrrAqvvxxrrrr"   r   cpu)r   c                 <    g | ]}t          |                   S r!   )r   )r  ir;   dprs     r#   r	  z*SegGptEncoder.__init__.<locals>.<listcomp>  s'    $j$j$jQ[Q%@%@$j$j$jr"   r   F)r-   r.   r;   r   linspacer   num_hidden_layersr   
ModuleListrangelayersr   r2   r   	layernormgradient_checkpointing)r:   r;   r  r<   s    `@r#   r.   zSegGptEncoder.__init__  s    rr63H&Jbkp!q!q!qrrrm$j$j$j$j$j%PVPhJiJi$j$j$jkkf&8f>STTT&+###r"   FTr   r   r   output_hidden_statesreturn_dictc                 :   |rdnd }|rdnd }g }t          | j                  D ]\  }	}
|r||fz   }| j        j        |	k    rdnd} |
||||          }|d         }|	| j        j        k    r4|d |j        d         dz           ||j        d         dz  d          z   dz  }|	| j        j        v r(|                    |                     |                     |r||d         fz   }|r||fz   }|st          d ||||fD                       S t          ||||          S )Nr!   r@   r   r   r`   c              3      K   | ]}||V  	d S r   r!   )r  vs     r#   	<genexpr>z(SegGptEncoder.forward.<locals>.<genexpr>  s0        =  === r"   )r   r   r   r   )
	enumerater  r;   merge_indexrA   !intermediate_hidden_state_indicesappendr  r    r   )r:   r   r   r   r  r  all_hidden_statesall_self_attentionsr   r  layer_moduler   layer_outputss                r#   rI   zSegGptEncoder.forward  s    #7@BBD$5?bb4%'"(55 	P 	POA|# I$58H$H! "&!81!<!<AA!M(LGWYjkkM)!,MDK+++!"?M$7$:a$?"?@=Q^QdefQgklQlQnQnCoo! DKAAA*11$..2O2OPPP  P&9]1=M<O&O# 	E 1]4D D 	  '):<OQkl     
 #++*'A	
 
 
 	
r"   )FFFT)r   r   r   r   r.   r   r   r  r   r    r   rI   rJ   rK   s   @r#   r  r    s        ,| , , , , , , , "'"'%* 0
 0
|0
 0
  	0

 #0
 0
 
u))	*0
 0
 0
 0
 0
 0
 0
 0
r"   r  c                   R     e Zd ZdZddd fd
Zdej        dej        f fdZ xZS )	SegGptLayerNormaA  LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
    width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
    gư>channels_last)r   data_formatc                z     t                      j        |fd|i| |dvrt          d|           || _        d S )Nr   )r%  channels_firstzUnsupported data format: )r-   r.   NotImplementedErrorr&  )r:   normalized_shaper   r&  kwargsr<   s        r#   r.   zSegGptLayerNorm.__init__  sY    )==s=f===AAA%&O+&O&OPPP&r"   featuresrN   c                    | j         dk    rR|                    dddd          }t                                          |          }|                    dddd          }n!t                                          |          }|S )z
        Args:
            features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
        r(  r   r@   r   r   )r&  rC   r-   rI   )r:   r,  r<   s     r#   rI   zSegGptLayerNorm.forward  sw    
 ///''1a33Hwwx00H''1a33HHwwx00Hr"   )	r   r   r   r   r.   r   r   rI   rJ   rK   s   @r#   r$  r$    s         
 15/ ' ' ' ' ' ' '           r"   r$  c                   4     e Zd Z fdZdej        fdZ xZS )SegGptDecoderHeadc                 J   t                                                       t          j        |j        |j        dd          | _        t          |j        |j        d          | _        t          |j
                 | _        t          j        |j        ddd          | _        d S )Nr   r   )r+   paddingr(  )r*  r   r&  T)r+   r   )r-   r.   r   r8   decoder_hidden_sizeconvr$  r   r  r	   r   act_fctheadr   s     r#   r.   zSegGptDecoderHead.__init__  s    I&&	
 
 
	 )#7V=R`p
 
 
 f/0If8!QUVVV			r"   r   c                     |                      |          }|                     |          }|                     |          }|                     |          }|S r   )r3  r  r4  r5  r   s     r#   rI   zSegGptDecoderHead.forward(  sL    		-00}55]33		-00r"   )r   r   r   r.   r   r   rI   rJ   rK   s   @r#   r/  r/    s[        W W W W WU%6        r"   r/  c                   X     e Zd Z fdZdej        dej        fdZdej        fdZ xZS )SegGptDecoderc                 :   t                                                       t          j        |j        t          |j                  z  |j        dz  |j        z  d          | _	        t          |          | _        |j        | _        |j        | _        || _        d S )Nr@   Tr   )r-   r.   r   r   r2   lenr  r0   r2  decoder_embedr/  decoder_predr;   r   s     r#   r.   zSegGptDecoder.__init__2  s    YV%M!N!NNq 6#==
 
 

 .f55 +#)#= r"   r   rN   c                     |j         \  }}}}|                    |||| j        | j        | j                  }|                    dddddd          }|                    |d|| j        z  || j        z  f          }|S )	Nr      r   r   r@   r   ra   rA   )rA   rj   r0   r2  rC   )r:   r   rE   r|   r}   r~   s         r#   _reshape_hidden_statesz$SegGptDecoder._reshape_hidden_states>  s    3@3F0
L+q%--k4?DOUYUm
 
 &--aAq!Q??%--r<$/#A;QUQ`C`a . 
 
 r"   c                     |                      |          }|                     |          }|                     |          }|S r   )r;  r@  r<  r   s     r#   rI   zSegGptDecoder.forwardJ  sA    **=9933MBB))-88r"   )	r   r   r   r.   r   r   r@  rI   rJ   rK   s   @r#   r8  r8  1  s~        
 
 
 
 

E4E 
%J[ 
 
 
 
U%6        r"   r8  c                   H    e Zd ZU eed<   dZdZdZddgZde	j
        dd	fd
Zd	S )SegGptPreTrainedModelr;   modelrD   TrM   r   modulerN   Nc                    | j         j        }t          |t          j        t          j        f          rt          j                            |j        j	        
                    t          j                  d|          
                    |j        j                  |j        _	        |j         |j        j	                                         dS dS t          |t          j        t"          f          r?|j        j	                                         |j        j	                            d           dS t          |t&                    rt          j                            |j        j	        
                    t          j                  d|          
                    |j        j                  |j        _	        t          j                            |j        j	        
                    t          j                  d|          
                    |j        j                  |j        _	        dS t          |t,                    rIt          j                            |j        j	        
                    t          j                  d|          
                    |j        j                  |j        _	        t          j        j                            |j        |           t          j        j                            |j        |           t          j        j                            |j        |           t          j        j                            |j        |           t          j        j                            |j        |           dS dS )zInitialize the weightsr   )r   stdNr   )rG  )r;   initializer_ranger3   r   r   r8   inittrunc_normal_weightdatar   r   r   r   r   zero_r   r$  fill_r   r   r   rM   rZ   normal_rR   rS   rT   rU   rV   )r:   rE  rG  s      r#   _init_weightsz#SegGptPreTrainedModel._init_weightsZ  s   k+fry")455 #	G "$!6!6v}7I7L7LU]7[7[bekn!6!o!o!r!r#" "FM {& &&((((( '& ?@@ 	GK""$$$M$$S)))))00 	G$&G$9$9 %((77 %: % % b!'((	 ! %'G$9$9 %((77 %: % % b!'((	 !!!  011 	G.0g.C.C*/225=AA /D / / b+122	 &+ HM!!&"3!===HM!!&"<#!FFFHM!!&"=3!GGGHM!!&"<#!FFFHM!!&"<#!FFFFF	G 	Gr"   )r   r   r   r   r   base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modulesr   ModulerP  r!   r"   r#   rC  rC  R  sj         $O&*#+];&GBI &G$ &G &G &G &G &G &Gr"   rC  c                   F    e Zd Zdef fdZdefdZdeee	e         f         ddfdZ
e	 	 	 	 	 	 	 ddej        d	ej        d
ej        deej                 dee         dee         deej                 dee         dee         dee         deeef         fd            Z xZS )SegGptModelr;   c                     t                                          |           || _        t          |          | _        t          |          | _        |                                  d S r   )r-   r.   r;   rM   rH   r  encoder	post_initr   s     r#   r.   zSegGptModel.__init__  sX       *622$V,, 	r"   rN   c                     | j         j        S r   )rH   rW   r   s    r#   get_input_embeddingsz SegGptModel.get_input_embeddings  s    //r"   heads_to_pruneNc                     |                                 D ]/\  }}| j        j        |         j                            |           0dS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsrY  layerr   prune_heads)r:   r]  r`  headss       r#   _prune_headszSegGptModel._prune_heads  sU    
 +0022 	C 	CLE5Lu%/;;EBBBB	C 	Cr"   rD   rn   prompt_masksro   r   rp   labelsr   r  r  c                 X   ||n| j         j        }|	|	n| j         j        }	|
|
n| j         j        }
||nd}| j        j        j        j        j        }|	                    |          }|	                    |          }t          j        ||fd          }|t          j        ||fd          nt          j        ||fd          }||t                              d           || j        j        j        }t          j        |dz  t          j        |j                  }t          j        ||dz  z
  t          j        |j                  }t          j        ||g          }|                    d          }|                     ||||          }|                     ||||	|
	          }|S )
a
  
        prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
            [`SegGptImageProcessor.__call__`] for details.
        prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
            details.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        feature_ensemble (`bool`, *optional*):
            Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
            if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
            be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
        embedding_type (`str`, *optional*):
            Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
            instance or semantic.
        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
            Ground truth mask for input images.

        Examples:

        ```python
        >>> from transformers import SegGptImageProcessor, SegGptModel
        >>> from PIL import Image
        >>> import requests

        >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
        >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
        >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

        >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
        >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

        >>> checkpoint = "BAAI/seggpt-vit-large"
        >>> model = SegGptModel.from_pretrained(checkpoint)
        >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)

        >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> list(outputs.last_hidden_state.shape)
        [1, 56, 28, 1024]
        ```
        NFr@   rt   zLabels were provided, but bool_masked_pos were not. It will be set to default value. If you're training the model, make sure to provide a bool_masked_pos.r   r   )rp   ro   )r   r   r  r  )r;   r   r  use_return_dictrH   rW   r9   rK  r   r   r   ry   loggerwarning_oncer7   rQ   r  r   onesrw   rY  )r:   rD   rn   rd  ro   r   rp   re  r   r  r  expected_dtyper7   bool_masked_pos_zerosbool_masked_pos_onesembedding_outputencoder_outputss                    r#   rI   zSegGptModel.forward  s   v 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]/?/K++QV9DKQ#~66144^DD y"5|!D!LLL ~ I|\2::::L&1q999 	 "v'9 m   "/:FK$)Kq0@
[g[n$o$o$o!#(:kQ..ejI\$ $ $  $i)>@T(UVVO-77::O??-n^m + 
 
 ,,-/!5# ' 
 
 r"   NNNNNNN)r   r   r   r   r.   r)   r\  dictr   listrc  r   r   r   r   r   r  r   r   r   r    r   rI   rJ   rK   s   @r#   rW  rW    s       |      0&; 0 0 0 0C4T#Y+? CD C C C C  7;+/(,.2,0/3&*k klk #\k l	k
 "%"23k #4.k !k *+k $D>k 'tnk d^k 
u))	*k k k ^k k k k kr"   rW  tensorr0   c                     | j         \  }}}}||z  }||z  }|                     ||||||f          } |                     dddddd          } |                     |||z  |dz  dz  f          } | S )Nr?  r   r@   r   r   r>  r   )rA   rj   rC   )rs  r0   rE   r1   rF   rG   r|   r}   s           r#   patchifyru  	  s    .4l+JfeZ'L:%K^^:|\:Wbdn"o^ppF^^Aq!Q1--F^^:|k/I:WX=[\K\"]^^^FMr"   r|   r}   c           	      |   | j         d         }t          | j         d         dz  dz            }||z  | j         d         k    r$t          d| j         d          d| d| d	          |                     |||||df
          } |                     dddddd          } |                     |d||z  ||z  f
          } | S )Nr   ra   r   r`   r   zNumber of patches z does not match patch height (z) and width (r?   r?  r>  r@   r   )rA   r   rB   rj   rC   )rs  r|   r}   rE   r0   s        r#   
unpatchifyrw    s    aJfl2&*s233Jk!V\!_44zazzP\zzkvzzz
 
 	
 ^^:|[*V`bc"d^eeF^^Aq!Q1--F^^:q,2K[[eMe"f^ggFMr"   c                   ^     e Zd Z fdZdej        dej        dej        dej        fdZ xZS )
SegGptLossc                 x    t                                                       |j        | _        |j        | _        d S r   )r-   r.   betar0   r   s     r#   r.   zSegGptLoss.__init__%  s0    K	 +r"   rd  r'   re  ro   c                    t          j        ||fd          }|dddddf                             dd| j        dz  dz            }t	          ||j        d         | j        z  |j        d         | j        z            }t          j        ||d| j                  }||z  	                                |	                                z  }|S )aN  Computes the L1 loss between the predicted masks and the ground truth masks.

        Args:
            prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Pixel values from mask prompt.

            pred_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, 2*height, width)`):
                Predicted masks.

            labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Ground truth mask for input images.

            bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
                Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).

        Returns:
            `torch.FloatTensor`: The mean L1 loss between the predicted masks and the ground truth masks.
        r@   rt   Nr   r   none)	reductionr{  )
r   ry   repeatr0   rw  rA   rh   smooth_l1_lossr{  sum)r:   rd  r'   re  ro   ground_truthmaskr&   s           r#   rI   zSegGptLoss.forward*  s    2 y,!7Q???qqq!!!Tz*11!Q8JQ8NOO$ 21 5 H,J\]^J_cgcrJrss
LFQUQZ[[[t  ""TXXZZ/r"   )	r   r   r   r.   r   r   r   rI   rJ   rK   s   @r#   ry  ry  $  s~        , , , , ,
!'! %! !	!
 )! ! ! ! ! ! ! !r"   ry  zM
    SegGpt model with a decoder on top for one-shot image segmentation.
    c                       e Zd Zdef fdZe	 	 	 	 	 	 	 ddej        dej        dej        deej	                 dee
         d	ee         d
eej                 dee
         dee
         dee
         deeef         fd            Z xZS )SegGptForImageSegmentationr;   c                     t                                          |           || _        t          |          | _        t          |          | _        |                                  d S r   )r-   r.   r;   rW  rD  r8  decoderrZ  r   s     r#   r.   z#SegGptForImageSegmentation.__init__T  sX        ((
$V,, 	r"   NrD   rn   rd  ro   r   rp   re  r   r  r  rN   c                 0   ||n| j         j        }|	|	n| j         j        }	|
|
n| j         j        }
|| j        j        j        j        }t          j	        |dz  t          j
        |j                  }t          j        ||dz  z
  t          j
        |j                  }t          j        ||g          }|                    d          }|                     |||||||||	|

  
        }|
r|j        n|d         }t          j        |d          }|                     |          }d}|"t#          | j                   } |||||          }|
s/|f}|	r||d         fz   }|r|	rdnd}|||         fz   }||f|z   }|S t%          |||j        |j        	          S )
aY  
        prompt_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt pixel values. Prompt pixel values can be obtained using [`AutoImageProcessor`]. See
            [`SegGptImageProcessor.__call__`] for details.
        prompt_masks (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Prompt mask. Prompt mask can be obtained using [`AutoImageProcessor`]. See [`SegGptImageProcessor.__call__`] for
            details.
        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
        feature_ensemble (`bool`, *optional*):
            Boolean indicating whether to use feature ensemble or not. If `True`, the model will use feature ensemble
            if we have at least two prompts. If `False`, the model will not use feature ensemble. This argument should
            be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
        embedding_type (`str`, *optional*):
            Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either
            instance or semantic.
        labels (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, `optional`):
            Ground truth mask for input images.

        Examples:

        ```python
        >>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
        >>> from PIL import Image
        >>> import requests

        >>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
        >>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
        >>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

        >>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
        >>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
        >>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")

        >>> checkpoint = "BAAI/seggpt-vit-large"
        >>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
        >>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)

        >>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
        >>> outputs = model(**inputs)
        >>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
        >>> print(list(result.shape))
        [170, 297]
        ```
        Nr@   r   r   )
rD   rn   rd  ro   r   rp   re  r   r  r  ra   rt   r   )r&   r'   r   r   )r;   r   r  rg  rD  rH   rW   r7   r   rQ   r  r   rj  ry   rw   r   r  ry  r%   r   r   )r:   rD   rn   rd  ro   r   rp   re  r   r  r  r7   rl  rm  r   r   r'   r&   loss_fnr   idxs                        r#   rI   z"SegGptForImageSegmentation.forward^  s   v 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B]"*/@LK$)Kq0@
[g[n$o$o$o!#(:kQ..ejI\$ $ $  $i)>@T(UVVO-77::O**% 3%+-)/!5#  
 
 LW%gW%G%G\cdf\g"%*Y/Ir%R%R%R"\\"<==
 --G7<V_MMD 	 ]F# 071:-/  2/6aaQ73</16)M,!!/)	
 
 
 	
r"   rp  )r   r   r   r   r.   r   r   r   r   r   r  r   r   r   r    r%   rI   rJ   rK   s   @r#   r  r  N  s5       |        7;+/(,.2,0/3&*q
 q
lq
 #\q
 l	q

 "%"23q
 #4.q
 !q
 *+q
 $D>q
 'tnq
 d^q
 
u33	4q
 q
 q
 ^q
 q
 q
 q
 q
r"   r  )rW  rC  r  )r   F)7r   collections.abcr4   dataclassesr   typingr   r   r   r   torch.nnr   rh   activationsr	   modeling_layersr
   modeling_utilsr   utilsr   r   r   r   configuration_seggptr   
get_loggerr   rh  r   r%   rU  r)   rM   r   r   r   r   r  r   r   r   r  r   r$  r/  r8  rC  rW  r   ru  rw  ry  r  __all__r!   r"   r#   <module>r     s*         ! ! ! ! ! ! " " " " " " " "        $ $ $ $ $ $ ! ! ! ! ! ! 9 9 9 9 9 9 - - - - - - D D D D D D D D D D D D . . . . . . 
	H	%	%   
J J J J J+ J J  J,   
: : : : :K : :  :*         BI      FR R R R Rry R R RjK4 K4 K4 K4 K4bi K4 K4 K4^    	    U\ e T V[Vb    *% % % % %RY % % %, , , , ,, , , ,^9
 9
 9
 9
 9
BI 9
 9
 9
z    bl   4    	   0    BI   B -G -G -G -G -GO -G -G -G` B B B B B' B B BJ	U\ 	s 	u| 	 	 	 	u| 3 S U\    ' ' ' ' ' ' ' 'T   
}
 }
 }
 }
 }
!6 }
 }
 
}
@ Q
P
Pr"   