
     `i                    J   d Z ddlZddlZddlmZ ddlmZmZmZ ddl	Z	ddl	m
Z
 ddlmZmZmZ ddlmZmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlmZmZmZmZmZm Z m!Z!m"Z" ddl#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z*m+Z+m,Z, ddl-m.Z. ddl/m0Z0  e,j1        e2          Z3dJdZ4 G d de
j5                  Z6 G d de
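# ---------------------------------------------------------------------------
# Editorial usage sketch (not part of the library API): converting an original
# TF ELECTRA checkpoint with `load_tf_weights_in_electra`. The file paths are
# hypothetical placeholders; pass "generator" (with an `ElectraForMaskedLM`
# instance) to load the generator weights instead of the discriminator's.
#
#   from transformers import ElectraConfig, ElectraForPreTraining
#
#   config = ElectraConfig.from_json_file("/path/to/electra_config.json")
#   model = ElectraForPreTraining(config)
#   load_tf_weights_in_electra(model, config, "/path/to/tf_checkpoint", "discriminator")
# ---------------------------------------------------------------------------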
j5                  Z7 G d de
j5                  Z8de7iZ9 G d de
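# ---------------------------------------------------------------------------
# Editorial usage sketch: `ElectraModel` as a plain encoder / feature
# extractor. The checkpoint name is an assumption; any ELECTRA checkpoint on
# the Hub works the same way.
#
#   from transformers import AutoTokenizer, ElectraModel
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#   model = ElectraModel.from_pretrained("google/electra-small-discriminator")
#   inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
#   outputs = model(**inputs)
#   features = outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)
# ---------------------------------------------------------------------------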
j5                  Z: G d de
j5                  Z; G d  d!e
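# ---------------------------------------------------------------------------
# Editorial usage sketch: masked-token prediction with `ElectraForMaskedLM`.
# Only the *generator* checkpoints were trained for MLM; the checkpoint name
# is an assumption.
#
#   import torch
#   from transformers import AutoTokenizer, ElectraForMaskedLM
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-generator")
#   model = ElectraForMaskedLM.from_pretrained("google/electra-small-generator")
#   inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
#   with torch.no_grad():
#       logits = model(**inputs).logits
#   mask_positions = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
#   print(tokenizer.decode(logits[0, mask_positions].argmax(dim=-1)))
# ---------------------------------------------------------------------------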
j5                  Z< G d" d#e          Z= G d$ d%e
j5                  Z> G d& d'e
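# ---------------------------------------------------------------------------
# Editorial usage sketch: a fine-tuning-style call to
# `ElectraForSequenceClassification`; passing `labels` makes the forward pass
# return a loss. Checkpoint name and label value are assumptions.
#
#   import torch
#   from transformers import AutoTokenizer, ElectraForSequenceClassification
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#   model = ElectraForSequenceClassification.from_pretrained(
#       "google/electra-small-discriminator", num_labels=2
#   )
#   inputs = tokenizer("A delightful movie.", return_tensors="pt")
#   outputs = model(**inputs, labels=torch.tensor([1]))
#   outputs.loss.backward()  # cross-entropy over the [CLS]-position head
# ---------------------------------------------------------------------------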
j5                  Z? G d( d)e
j5                  Z@e+ G d* d+e$                      ZAe e+d,-           G d. d/e*                                  ZBe+ G d0 d1eA                      ZC G d2 d3e
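# ---------------------------------------------------------------------------
# Editorial usage sketch: token-level tagging (e.g. NER-style labels) with
# `ElectraForTokenClassification`. Checkpoint name and label count are
# assumptions.
#
#   import torch
#   from transformers import AutoTokenizer, ElectraForTokenClassification
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#   model = ElectraForTokenClassification.from_pretrained(
#       "google/electra-small-discriminator", num_labels=5
#   )
#   inputs = tokenizer("HuggingFace is based in New York City", return_tensors="pt")
#   with torch.no_grad():
#       logits = model(**inputs).logits  # (batch_size, seq_len, num_labels)
#   predicted_ids = logits.argmax(dim=-1)
# ---------------------------------------------------------------------------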
j5                  ZD G d4 d5e
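# ---------------------------------------------------------------------------
# Editorial usage sketch: extractive QA with `ElectraForQuestionAnswering`.
# The head predicts start/end logits over the sequence; checkpoint name and
# question/context strings are assumptions.
#
#   import torch
#   from transformers import AutoTokenizer, ElectraForQuestionAnswering
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#   model = ElectraForQuestionAnswering.from_pretrained("google/electra-small-discriminator")
#   question, context = "Who proposed ELECTRA?", "ELECTRA was proposed by Clark et al."
#   inputs = tokenizer(question, context, return_tensors="pt")
#   with torch.no_grad():
#       outputs = model(**inputs)
#   start = outputs.start_logits.argmax(dim=-1)
#   end = outputs.end_logits.argmax(dim=-1)
#   answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])
# ---------------------------------------------------------------------------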
j5                  ZE e+d6-           G d7 d8eA                      ZF e+d9-           G d: d;eA                      ZG e+d<-           G d= d>eA                      ZH e+d?-           G d@ dAeA                      ZIe+ G dB dCeA                      ZJe+ G dD dEeA                      ZK e+dF-           G dG dHeAe                      ZLg dIZMdS )KzPyTorch ELECTRA model.    N)	dataclass)CallableOptionalUnion)nn)BCEWithLogitsLossCrossEntropyLossMSELoss   )ACT2FNget_activation)CacheDynamicCacheEncoderDecoderCache)GenerationMixin)GradientCheckpointingLayer)"BaseModelOutputWithCrossAttentions)BaseModelOutputWithPastAndCrossAttentions!CausalLMOutputWithCrossAttentionsMaskedLMOutputMultipleChoiceModelOutputQuestionAnsweringModelOutputSequenceClassifierOutputTokenClassifierOutput)PreTrainedModel)apply_chunking_to_forward find_pruneable_heads_and_indicesprune_linear_layer)ModelOutputauto_docstringlogging)deprecate_kwarg   )ElectraConfigdiscriminatorc                    	 ddl }ddl}ddl}n)# t          $ r t                              d            w xY wt          j                            |          }t          	                    d|            |j
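# ---------------------------------------------------------------------------
# Editorial usage sketch: multiple-choice scoring with
# `ElectraForMultipleChoice`. Inputs are shaped
# (batch_size, num_choices, seq_len) and flattened internally; checkpoint name
# and prompt are assumptions.
#
#   import torch
#   from transformers import AutoTokenizer, ElectraForMultipleChoice
#
#   tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#   model = ElectraForMultipleChoice.from_pretrained("google/electra-small-discriminator")
#   prompt = "The sky is"
#   choices = ["blue.", "an elephant."]
#   encoding = tokenizer([prompt, prompt], choices, return_tensors="pt", padding=True)
#   outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})  # add batch dim
#   best_choice = outputs.logits.argmax(dim=-1)
# ---------------------------------------------------------------------------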
                            |          }g }	g }
|D ]j\  }}t          	                    d| d|            |j
                            ||          }|	                    |           |
                    |           kt          |	|
          D ]\  }}|}	 t          | t                     r|                    dd          }|d	k    r,|                    d
d          }|                    dd
          }|                    dd          }|                    dd          }|                    d          }t'          d |D                       rt          	                    d|            | }|D ]}|                    d|          r|                    d|          }n|g}|d         dk    s|d         dk    rt+          |d          }ny|d         dk    s|d         dk    rt+          |d          }nP|d         dk    rt+          |d          }n3|d         dk    rt+          |d          }nt+          ||d                   }t-          |          dk    rt/          |d                    }||         }|                    d!          rt+          |d          }n|dk    r|                    |          }	 |j        |j        k    r t7          d"|j         d#|j         d$          n/# t6          $ r"}|xj        |j        |j        fz  c_         d}~ww xY wt;          d%| |           t=          j        |          |_         # tB          $ r}t;          d| ||           Y d}~d}~ww xY w| S )&z'Load tf checkpoints in a pytorch model.r   NzLoading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see https://www.tensorflow.org/install/ for installation instructions.z&Converting TensorFlow checkpoint from zLoading TF weight z with shape zelectra/embeddings/zgenerator/embeddings/	generatorzelectra/zdiscriminator/z
generator/dense_1dense_predictionz!generator_predictions/output_biaszgenerator_lm_head/bias/c              3      K   | ]}|d v V  	dS ))global_steptemperatureN ).0ns     /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/transformers/models/electra/modeling_electra.py	<genexpr>z-load_tf_weights_in_electra.<locals>.<genexpr>\   s(      EE1166EEEEEE    z	Skipping z[A-Za-z]+_\d+z_(\d+)kernelgammaweightoutput_biasbetabiasoutput_weightssquad
classifier   r#   _embeddingszPointer shape z and array shape z mismatchedzInitialize PyTorch weight )"renumpy
tensorflowImportErrorloggererrorospathabspathinfotrainlist_variablesload_variableappendzip
isinstanceElectraForMaskedLMreplacesplitany	fullmatchgetattrlenintendswith	transposeshape
ValueErrorargsprinttorch
from_numpydataAttributeError)modelconfigtf_checkpoint_pathdiscriminator_or_generatorr?   nptftf_path	init_varsnamesarraysnamerY   arrayoriginal_namepointerm_namescope_namesnumes                       r1   load_tf_weights_in_electrars   2   s   
			   Q	
 	
 	
 	 goo011G
KKBBBCCC''00IEF   eBBB5BBCCC&&w55Te5&)) 6 6e!3	%!344 T||$9;RSS)[88||J0@AA||L*==<<	+=>>D<< CE]^^D::c??D EEEEEEE 777888G + +<< 0&99 +"$((9f"="=KK#)(Kq>X--Q71J1J%gx88GG ^}44A&8P8P%gv66GG ^'777%gx88GG ^w..%g|<<GG%g{1~>>G{##q((k!n--C%clG}-- ,!'8448##U++=EK//$%ngm%n%nV[Va%n%n%nooo 0   7=%+66 5t55}EEE +E22GLL 	 	 	-m--tQ777HHHH	 LsL    &5CO-EO:0M+*O+
N5NN/O
O1O,,O1c                        e Zd ZdZ fdZ	 	 	 	 	 ddeej                 deej                 deej                 deej                 d	e	d
ej
        fdZ xZS )ElectraEmbeddingszGConstruct the embeddings from word, position and token_type embeddings.c                    t                                                       t          j        |j        |j        |j                  | _        t          j        |j        |j                  | _	        t          j        |j
        |j                  | _        t          j        |j        |j                  | _        t          j        |j                  | _        |                     dt%          j        |j                                      d          d           t+          |dd          | _        |                     d	t%          j        | j                                        t$          j        
          d           d S )N)padding_idxepsposition_ids)r#   F)
persistentposition_embedding_typeabsolutetoken_type_idsdtype)super__init__r   	Embedding
vocab_sizeembedding_sizepad_token_idword_embeddingsmax_position_embeddingsposition_embeddingstype_vocab_sizetoken_type_embeddings	LayerNormlayer_norm_epsDropouthidden_dropout_probdropoutregister_bufferr]   arangeexpandrT   r}   zerosrz   sizelongselfrb   	__class__s     r1   r   zElectraEmbeddings.__init__   sM   !|F,=v?Tbhbuvvv#%<0NPVPe#f#f %'\&2H&J_%`%`" f&;AVWWWz&"<== 	EL)GHHOOPWXXej 	 	
 	
 	
 (/v7PR\']']$ek$*;*@*@*B*B%*UUUbg 	 	
 	
 	
 	
 	
r3   Nr   	input_idsr   rz   inputs_embedspast_key_values_lengthreturnc                    ||                                 }n|                                 d d         }|d         }|| j        d d |||z   f         }|mt          | d          r2| j        d d d |f         }|                    |d         |          }	|	}n+t          j        |t
          j        | j        j                  }|| 	                    |          }| 
                    |          }
||
z   }| j        dk    r|                     |          }||z  }|                     |          }|                     |          }|S )Nr{   r#   r   r   r   devicer~   )r   rz   hasattrr   r   r]   r   r   r   r   r   r}   r   r   r   )r   r   r   rz   r   r   input_shape
seq_lengthbuffered_token_type_ids buffered_token_type_ids_expandedr   
embeddingsr   s                r1   forwardzElectraEmbeddings.forward   sm     #..**KK',,..ss3K ^
,QQQ0FVlIl0l-lmL
 !t-.. m*.*=aaa*n*M'3J3Q3QR]^_R`bl3m3m0!A!&[
SWSdSk!l!l!l  00;;M $ : :> J J"%::
':55"&":":<"H"H--J^^J//
\\*--
r3   )NNNNr   )__name__
__module____qualname____doc__r   r   r]   
LongTensorFloatTensorrV   Tensorr   __classcell__r   s   @r1   ru   ru      s        QQ
 
 
 
 
. 15593759&'' 'E,-' !!12' u/0	'
   12' !$' 
' ' ' ' ' ' ' 'r3   ru   c                       e Zd Zd fd	Z eddd          	 	 	 	 	 	 ddej        d	eej                 d
eej                 deej                 dee	         dee
         deej                 deej                 fd            Z xZS )ElectraSelfAttentionNc                 R   t                                                       |j        |j        z  dk    r0t	          |d          s t          d|j         d|j         d          |j        | _        t          |j        |j        z            | _        | j        | j        z  | _        t          j
        |j        | j                  | _        t          j
        |j        | j                  | _        t          j
        |j        | j                  | _        t          j        |j                  | _        |pt#          |dd          | _        | j        dk    s| j        d	k    r6|j        | _        t          j        d
|j        z  dz
  | j                  | _        |j        | _        || _        d S )Nr   r   zThe hidden size (z6) is not a multiple of the number of attention heads ()r}   r~   relative_keyrelative_key_queryr=   r#   )r   r   hidden_sizenum_attention_headsr   rZ   rV   attention_head_sizeall_head_sizer   Linearquerykeyvaluer   attention_probs_dropout_probr   rT   r}   r   r   distance_embedding
is_decoder	layer_idxr   rb   r}   r   r   s       r1   r   zElectraSelfAttention.__init__   s    ::a??PVXhHiHi?8F$6 8 8 48 8 8  
 $*#= #&v'9F<V'V#W#W !58PPYv143EFF
9V/1CDDYv143EFF
z&"EFF'> (
'-zC
 C
$ '>99T=Y]q=q=q+1+ID(&(l1v7U3UXY3Y[_[s&t&tD# +"r3   past_key_valuepast_key_values4.58new_nameversionFhidden_statesattention_mask	head_maskencoder_hidden_statesoutput_attentionscache_positionr   c                    |j         \  }}	}
|                     |          }|                    |d| j        | j                                      dd          }d}|d u}|Ht          |t                    r1|j        	                    | j
                  }|r|j        }n
|j        }n|}|r|n|}|r3|1|r/|j        | j
                 j        }|j        | j
                 j        }n|                     |          }|                    |d| j        | j                                      dd          }|                     |          }|                    |d| j        | j                                      dd          }|N|s|nd }|                    ||| j
        d|i          \  }}|r$t          |t                    rd|j        | j
        <   t'          j        ||                    dd                    }| j        dk    s| j        d	k    rt|j         d         |j         d         }}|>t'          j        |dz
  t&          j        |j        
                              dd          }n:t'          j        |t&          j        |j        
                              dd          }t'          j        |t&          j        |j        
                              dd          }||z
  }|                     || j        z   dz
            }|                    |j                  }| j        dk    rt'          j        d||          }||z   }n?| j        d	k    r4t'          j        d||          }t'          j        d||          }||z   |z   }|t?          j         | j                  z  }|||z   }tB          j"        #                    |d          }| $                    |          }|||z  }t'          j        ||          }|%                    dddd          &                                }|'                                d d         | j(        fz   }|                    |          }||fS )Nr{   r#   r=   Fr   Tr   r   r   r   zbhld,lrd->bhlrzbhrd,lrd->bhlrdimr   r   ))rY   r   viewr   r   rX   rN   r   
is_updatedgetr   cross_attention_cacheself_attention_cachelayerskeysvaluesr   r   updater]   matmulr}   tensorr   r   r   r   r   tor   einsummathsqrtr   
functionalsoftmaxr   permute
contiguousr   r   )r   r   r   r   r   r   r   r   
batch_sizer   _query_layerr   is_cross_attentioncurr_past_key_valuecurrent_states	key_layervalue_layerattention_scoresquery_length
key_lengthposition_ids_lposition_ids_rdistancepositional_embeddingrelative_position_scoresrelative_position_scores_queryrelative_position_scores_keyattention_probscontext_layernew_context_layer_shapes                                  r1   r   zElectraSelfAttention.forward   s    %2$7!
Jjj//!&&z2t7OQUQijjttq
 
 
2$>&/+>?? 6,7;;DNKK
% O*9*O''*9*N''&5#2DW..- 	F/"=*"=+24>BGI-4T^DKKK00I!z2t7OQUQijjtt1 I **^44K%**B 8$:R i1oo  *7I!St)<)C)C{DN=M~<^* *&	; & F*_FY*Z*Z FAEO.t~> !<Y5H5HR5P5PQQ'>99T=Y]q=q=q'2'8';Y_Q=O*L*!&j1nEJWdWk!l!l!l!q!q" " "'l%*UbUi!j!j!j!o!oprtu!v!v"\*EJ}OcdddiijkmoppN%6H#'#:#:8dFb;bef;f#g#g #7#:#:AR#:#S#S +~==+0<8H+Wk+l+l(#36N#N  -1EEE16>NP[]q1r1r./4|<LiYm/n/n,#36T#TWs#s +di8P.Q.QQ%/.@ -//0@b/II ,,77  -	9O_kBB%--aAq99DDFF"/"4"4"6"6ss";t?Q>S"S%**+BCCo--r3   NNNNNNFN)r   r   r   r   r"   r]   r   r   r   r   booltupler   r   r   s   @r1   r   r      s       # # # # # #6 _%0A6RRR 7;15=A+/,115e. e.|e. !!23e. E-.	e.
  ((9:e. "%e. $D>e. !.e. 
u|	e. e. e. SRe. e. e. e. e.r3   r   c                   P     e Zd Z fdZdej        dej        dej        fdZ xZS )ElectraSelfOutputc                    t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        t          j        |j	                  | _
        d S Nrx   )r   r   r   r   r   denser   r   r   r   r   r   s     r1   r   zElectraSelfOutput.__init__N  sf    Yv163EFF
f&8f>STTTz&"<==r3   r   input_tensorr   c                     |                      |          }|                     |          }|                     ||z             }|S Nr  r   r   r   r   r  s      r1   r   zElectraSelfOutput.forwardT  @    

=11]33}|'CDDr3   r   r   r   r   r]   r   r   r   r   s   @r1   r  r  M  i        > > > > >U\  RWR^        r3   r  eagerc                       e Zd Zd fd	Zd Z eddd          	 	 	 	 	 	 dd	ej        d
eej	                 deej	                 deej	                 dee
         dee         deej                 deej                 fd            Z xZS )ElectraAttentionNc                     t                                                       t          |j                 |||          | _        t          |          | _        t                      | _        d S )Nr}   r   )	r   r   ELECTRA_SELF_ATTENTION_CLASSES_attn_implementationr   r  outputsetpruned_headsr   s       r1   r   zElectraAttention.__init__b  sc    263NO$;
 
 
	
 (//EEr3   c                    t          |          dk    rd S t          || j        j        | j        j        | j                  \  }}t          | j        j        |          | j        _        t          | j        j        |          | j        _        t          | j        j	        |          | j        _	        t          | j
        j        |d          | j
        _        | j        j        t          |          z
  | j        _        | j        j        | j        j        z  | j        _        | j                            |          | _        d S )Nr   r#   r   )rU   r   r   r   r   r  r   r   r   r   r  r  r   union)r   headsindexs      r1   prune_headszElectraAttention.prune_headsl  s    u::??F7490$)2OQUQb
 
u
 -TY_eDD	*49=%@@	,TY_eDD	.t{/@%QOOO )-	(EE

(R	%"&)"?$)B_"_	 -33E::r3   r   r   r   r   Fr   r   r   r   r   r   r   c           	          |                      |||||||          }|                     |d         |          }	|	f|dd          z   }
|
S )Nr   r   r   r   r   r   r   r#   )r   r  )r   r   r   r   r   r   r   r   self_outputsattention_outputoutputss              r1   r   zElectraAttention.forward~  sf     yy)"7+/) ! 
 
  ;;|AFF#%QRR(88r3   r   r   )r   r   r   r   r  r"   r]   r   r   r   r   r   r   r   r   r   s   @r1   r  r  a  s       " " " " " "; ; ;$ _%0A6RRR 7;15=A+/,115 | !!23 E-.	
  ((9: "% $D> !. 
u|	   SR    r3   r  c                   B     e Zd Z fdZdej        dej        fdZ xZS )ElectraIntermediatec                    t                                                       t          j        |j        |j                  | _        t          |j        t                    rt          |j                 | _        d S |j        | _        d S r  )r   r   r   r   r   intermediate_sizer  rN   
hidden_actstrr   intermediate_act_fnr   s     r1   r   zElectraIntermediate.__init__  sn    Yv163KLL
f'-- 	9'-f.?'@D$$$'-'8D$$$r3   r   r   c                 Z    |                      |          }|                     |          }|S r  )r  r(  )r   r   s     r1   r   zElectraIntermediate.forward  s,    

=1100??r3   r  r   s   @r1   r#  r#    s^        9 9 9 9 9U\ el        r3   r#  c                   P     e Zd Z fdZdej        dej        dej        fdZ xZS )ElectraOutputc                    t                                                       t          j        |j        |j                  | _        t          j        |j        |j                  | _        t          j	        |j
                  | _        d S r  )r   r   r   r   r%  r   r  r   r   r   r   r   r   s     r1   r   zElectraOutput.__init__  sf    Yv79KLL
f&8f>STTTz&"<==r3   r   r  r   c                     |                      |          }|                     |          }|                     ||z             }|S r  r	  r
  s      r1   r   zElectraOutput.forward  r  r3   r  r   s   @r1   r+  r+    r  r3   r+  c                   0    e Zd Zd fd	Z eddd          	 	 	 	 	 	 	 ddej        d	eej                 d
eej                 deej                 deej                 dee	         dee
         deej                 deej                 fd            Zd Z xZS )ElectraLayerNc                    t                                                       |j        | _        d| _        t	          ||          | _        |j        | _        |j        | _        | j        r0| j        st          |  d          t	          |d|          | _	        t          |          | _        t          |          | _        d S )Nr#   r   z> should be used as a decoder model if cross attention is addedr~   r  )r   r   chunk_size_feed_forwardseq_len_dimr  	attentionr   add_cross_attentionrZ   crossattentionr#  intermediater+  r  r   rb   r   r   s      r1   r   zElectraLayer.__init__  s    '-'E$)&IFFF +#)#= # 	t? j D!h!h!hiii"26S]ir"s"s"sD/77#F++r3   r   r   r   r   Fr   r   r   r   encoder_attention_maskr   r   r   c	           	      h   |                      ||||||          }	|	d         }
|	dd          }| j        rV|Tt          | d          st          d|  d          |                     |
||||||          }|d         }
||dd          z   }t          | j        | j        | j        |
          }|f|z   }|S )N)r   r   r   r   r   r   r#   r6  z'If `encoder_hidden_states` are passed, z` has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`r  )	r4  r   r   rZ   r6  r   feed_forward_chunkr2  r3  )r   r   r   r   r   r9  r   r   r   self_attention_outputsr   r!  cross_attention_outputslayer_outputs                 r1   r   zElectraLayer.forward  s!    "&)/+) "0 "
 "
 2!4(,? 	<4@4!122  Dd D D D  
 '+&9&9 5#&; /"3- ': ' '#  7q9 7 ;;G0#T%A4CSUe
 
  /G+r3   c                 \    |                      |          }|                     ||          }|S r  )r7  r  )r   r   intermediate_outputr>  s       r1   r;  zElectraLayer.feed_forward_chunk  s2    "//0@AA{{#68HIIr3   r  )NNNNNFN)r   r   r   r   r"   r]   r   r   r   r   r   r   r   r;  r   r   s   @r1   r/  r/    s.       , , , , , , _%0A6RRR 7;15=A>B+/,115. .|. !!23. E-.	.
  ((9:. !)): ;. "%. $D>. !.. 
u|	. . . SR.`      r3   r/  c                   H    e Zd Zd fd	Z	 	 	 	 	 	 	 	 	 	 ddej        deej                 deej                 deej                 d	eej                 d
ee         dee	         dee	         dee	         dee	         deej                 de
eej                 ef         fdZ xZS )ElectraEncoderNc                     t                                                       | _        t          j        fdt          j                  D                       | _        d| _        d S )Nc                 2    g | ]}t          |           S )r1  )r/  )r/   irb   s     r1   
<listcomp>z+ElectraEncoder.__init__.<locals>.<listcomp>  s&    #o#o#o!L1$E$E$E#o#o#or3   F)	r   r   rb   r   
ModuleListrangenum_hidden_layerslayergradient_checkpointingr8  s    ` r1   r   zElectraEncoder.__init__  sa    ]#o#o#o#ouU[UmOnOn#o#o#opp
&+###r3   FTr   r   r   r   r9  r   	use_cacher   output_hidden_statesreturn_dictr   r   c                    |	rdnd }|rdnd }|r| j         j        rdnd }| j        r%| j        r|rt                              d           d}|rD| j         j        r8|6t          t          | j                   t          | j                             }|rO| j         j        rCt          |t                    r.t                              d           t          j        |          }t          | j                  D ]Z\  }}|	r||fz   }|||         nd } |||||||||          }|d         }|r$||d         fz   }| j         j        r||d	         fz   }[|	r||fz   }|
st          d
 |||||fD                       S t          |||||          S )Nr.   zZ`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...F)rb   zPassing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.)r9  r   r   r   r   r#   r=   c              3      K   | ]}||V  	d S r  r.   )r/   vs     r1   r2   z)ElectraEncoder.forward.<locals>.<genexpr>D  s4       
 
 =  !===
 
r3   )last_hidden_stater   r   
attentionscross_attentions)rb   r5  rK  trainingrC   warning_oncer   r   r   rN   r   from_legacy_cache	enumeraterJ  r   )r   r   r   r   r   r9  r   rL  r   rM  rN  r   all_hidden_statesall_self_attentionsall_cross_attentionsrE  layer_modulelayer_head_masklayer_outputss                      r1   r   zElectraEncoder.forward  sc    #7@BBD$5?bb4%6d4;;Zdrr`d& 	"4= 	" "##p   "	 	v/ 	vO4K1,dk2R2R2RT`hlhsTtTtTtuuO 	U/ 	UJPU4V4V 	U\  
 2COTTO(44 	V 	VOA|# I$58H$H!.7.CillO(L%'= /"3-	 	 	M *!,M  V&9]1=M<O&O#;2 V+?=QRCSBU+U( 	E 1]4D D 	 
 
 "#%'(
 
 
 
 
 
 9+++*1
 
 
 	
r3   r  )
NNNNNNFFTN)r   r   r   r   r]   r   r   r   r   r   r   r   r   r   r   r   s   @r1   rB  rB    sP       , , , , , , 7;15=A>B+/$(,1/4&*15P
 P
|P
 !!23P
 E-.	P

  ((9:P
 !)): ;P
 "%P
 D>P
 $D>P
 'tnP
 d^P
 !.P
 
uU\"$MM	NP
 P
 P
 P
 P
 P
 P
 P
r3   rB  c                   (     e Zd ZdZ fdZd Z xZS )ElectraDiscriminatorPredictionszEPrediction module for the discriminator, made up of two dense layers.c                    t                                                       t          j        |j        |j                  | _        t          |j                  | _        t          j        |j        d          | _	        || _
        d S Nr#   )r   r   r   r   r   r  r   r&  
activationr)   rb   r   s     r1   r   z(ElectraDiscriminatorPredictions.__init__[  sf    Yv163EFF
():;; "	&*<a @ @r3   c                     |                      |          }|                     |          }|                     |                              d          }|S )Nr{   )r  rc  r)   squeeze)r   discriminator_hidden_statesr   logitss       r1   r   z'ElectraDiscriminatorPredictions.forwardc  sK    

#>??66&&}55==bAAr3   r   r   r   r   r   r   r   r   s   @r1   r`  r`  X  sM        OO          r3   r`  c                   (     e Zd ZdZ fdZd Z xZS )ElectraGeneratorPredictionszAPrediction module for the generator, made up of two dense layers.c                    t                                                       t          d          | _        t	          j        |j        |j                  | _        t	          j        |j	        |j                  | _
        d S )Ngelurx   )r   r   r   rc  r   r   r   r   r   r   r  r   s     r1   r   z$ElectraGeneratorPredictions.__init__n  sa    (00f&;AVWWWYv163HII


r3   c                     |                      |          }|                     |          }|                     |          }|S r  )r  rc  r   )r   generator_hidden_statesr   s      r1   r   z#ElectraGeneratorPredictions.forwardu  s<    

#:;;66}55r3   rh  r   s   @r1   rj  rj  k  sR        KKJ J J J J      r3   rj  c                   ,    e Zd ZU eed<   eZdZdZd Z	dS )ElectraPreTrainedModelrb   electraTc                    t          |t          j                  rT|j        j                            d| j        j                   |j         |j        j        	                                 dS dS t          |t          j
                  r_|j        j                            d| j        j                   |j        +|j        j        |j                 	                                 dS dS t          |t          j                  r?|j        j        	                                 |j        j                            d           dS dS )zInitialize the weightsg        )meanstdNg      ?)rN   r   r   r6   r_   normal_rb   initializer_ranger9   zero_r   rw   r   fill_)r   modules     r1   _init_weightsz$ElectraPreTrainedModel._init_weights  s)   fbi(( 	* M&&CT[5R&SSS{& &&((((( '&-- 	*M&&CT[5R&SSS!-"6#56<<>>>>> .--- 	*K""$$$M$$S)))))	* 	*r3   N)
r   r   r   r$   __annotations__rs   load_tf_weightsbase_model_prefixsupports_gradient_checkpointingrz  r.   r3   r1   rp  rp  }  sB         0O!&*#* * * * *r3   rp  z3
    Output type of [`ElectraForPreTraining`].
    )custom_introc                       e Zd ZU dZdZeej                 ed<   dZ	eej                 ed<   dZ
eeej                          ed<   dZeeej                          ed<   dS )ElectraForPreTrainingOutputa+  
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss of the ELECTRA objective.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Prediction scores of the head (scores for each token before SoftMax).
    Nlossrg  r   rS  )r   r   r   r   r  r   r]   r   r{  rg  r   r   rS  r.   r3   r1   r  r    s           )-D(5$
%,,,*.FHU&'...8<M8E%"345<<<59Ju01299999r3   r  c                        e Zd Z fdZd Zd Zd Ze	 	 	 	 	 	 	 	 	 	 	 	 	 ddee	j
                 dee	j
                 dee	j
                 d	ee	j
                 d
ee	j
                 dee	j
                 dee	j
                 dee	j
                 dee         dee         dee         dee         dee         deee	j
                 ef         fd            Z xZS )ElectraModelc                 8   t                                          |           t          |          | _        |j        |j        k    r$t          j        |j        |j                  | _        t          |          | _
        || _        |                                  d S r  )r   r   ru   r   r   r   r   r   embeddings_projectrB  encoderrb   	post_initr   s     r1   r   zElectraModel.__init__  s       +F33 F$666&(i0EvGY&Z&ZD#%f--r3   c                     | j         j        S r  r   r   r   s    r1   get_input_embeddingsz!ElectraModel.get_input_embeddings  s    ..r3   c                     || j         _        d S r  r  )r   r   s     r1   set_input_embeddingsz!ElectraModel.set_input_embeddings  s    */'''r3   c                     |                                 D ]/\  }}| j        j        |         j                            |           0dS )z
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        N)itemsr  rJ  r4  r  )r   heads_to_prunerJ  r  s       r1   _prune_headszElectraModel._prune_heads  sU    
 +0022 	C 	CLE5Lu%/;;EBBBB	C 	Cr3   Nr   r   r   rz   r   r   r   r9  r   rL  r   rM  rN  r   c                    ||n| j         j        }||n| j         j        }||n| j         j        }||t	          d          |+|                     ||           |                                }n.||                                d d         }nt	          d          |\  }}||j        n|j        }d}|	Bt          |	t                    s|	d         d         j
        d         n|	                                }|t          j        ||          }|gt          | j        d          r1| j        j        d d d |f         }|                    ||          }|}n!t          j        |t          j        |          }|                     ||          }| j         j        rL|J|                                \  }}}||f}|t          j        ||          }|                     |          }nd }|                     || j         j                  }|                     |||||	          }t          | d
          r|                     |          }|                     ||||||	|
|||
  
        }|S )NzDYou cannot specify both input_ids and inputs_embeds at the same timer{   z5You have to specify either input_ids or inputs_embedsr   r   )r   r   r   )r   rz   r   r   r   r  )	r   r   r   r9  r   rL  r   rM  rN  )rb   r   rM  use_return_dictrZ   %warn_if_padding_and_no_attention_maskr   r   rN   r   rY   get_seq_lengthr]   onesr   r   r   r   r   r   get_extended_attention_maskr   invert_attention_maskget_head_maskrI  r  r  )r   r   r   r   rz   r   r   r   r9  r   rL  r   rM  rN  r   r   r   r   r   r   r   extended_attention_maskencoder_batch_sizeencoder_sequence_lengthr   encoder_hidden_shapeencoder_extended_attention_maskr   s                               r1   r   zElectraModel.forward  s   " 2C1N--TXT_Tq$8$D  $+Jj 	 &1%<kk$+B] ]%>cddd"66y.QQQ#..**KK&',,..ss3KKTUUU!,
J%.%:!!@T!"& "/5996"1%+B//$3355 # !"ZFCCCN!t(899 [*./*HKZK*X'3J3Q3QR\^h3i3i0!A!&[
SY!Z!Z!Z"&"B"B>S^"_"_ ;! 	3&;&G=R=W=W=Y=Y: 7$68O#P %-).4HQW)X)X)X&.2.H.HI_.`.`++.2+&&y$+2OPP	%)'#9 ( 
 
 4-.. 	C 33MBBM2"7#B+/!5# % 
 
 r3   )NNNNNNNNNNNNN)r   r   r   r   r  r  r  r    r   r]   r   r   r   r   r   r   r   r   r   s   @r1   r  r    s       
 
 
 
 
/ / /0 0 0C C C  -11515/3,0048<9=+/$(,0/3&*\ \EL)\ !.\ !.	\
 u|,\ EL)\  -\  (5\ !) 6\ "%\ D>\ $D>\ 'tn\ d^\ 
uU\"$FF	G\ \ \ ^\ \ \ \ \r3   r  c                   (     e Zd ZdZ fdZd Z xZS )ElectraClassificationHeadz-Head for sentence-level classification tasks.c                 \   t                                                       t          j        |j        |j                  | _        |j        |j        n|j        }t          d          | _	        t          j
        |          | _        t          j        |j        |j                  | _        d S )Nrl  )r   r   r   r   r   r  classifier_dropoutr   r   rc  r   r   
num_labelsout_projr   rb   r  r   s      r1   r   z"ElectraClassificationHead.__init__(  s    Yv163EFF
)/)B)NF%%TZTn 	 )00z"455	&"4f6GHHr3   c                     |d d dd d f         }|                      |          }|                     |          }|                     |          }|                      |          }|                     |          }|S )Nr   )r   r  rc  r  )r   featureskwargsxs       r1   r   z!ElectraClassificationHead.forward2  sn    QQQ111WLLOOJJqMMOOALLOOMM!r3   rh  r   s   @r1   r  r  %  sR        77I I I I I      r3   r  c                   j     e Zd ZdZdef fdZ	 d	dej        deej	                 dej        fdZ
 xZS )
ElectraSequenceSummarya  
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`ElectraConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
    rb   c                 V   t                                                       t          |dd          | _        | j        dk    rt          t          j                    | _        t          |d          rW|j	        rPt          |d          r|j
        r|j        dk    r|j        }n|j        }t          j        |j        |          | _        t          |dd           }|rt          |          nt          j                    | _        t          j                    | _        t          |d          r)|j        dk    rt          j        |j                  | _        t          j                    | _        t          |d	          r+|j        dk    r"t          j        |j                  | _        d S d S d S )
Nsummary_typelastattnsummary_use_projsummary_proj_to_labelsr   summary_activationsummary_first_dropoutsummary_last_dropout)r   r   rT   r  NotImplementedErrorr   Identitysummaryr   r  r  r  r   r   r   rc  first_dropoutr  r   last_dropoutr  )r   rb   num_classesactivation_stringr   s       r1   r   zElectraSequenceSummary.__init__W  s   #FNFCC&& &%{}}6-.. 	F63J 	Fv788 1V=Z 1_e_pst_t_t$/$09V%7EEDL#F,@$GGIZ$mN3D$E$E$E`b`k`m`m[]]6233 	J8TWX8X8X!#F,H!I!IDKMM6122 	Hv7RUV7V7V "
6+F G GD	H 	H7V7Vr3   Nr   	cls_indexr   c                 :   | j         dk    r|dddf         }n-| j         dk    r|dddf         }n| j         dk    r|                    d          }n| j         d	k    r|=t          j        |d
ddddf         |j        d         dz
  t          j                  }nl|                    d                              d          }|                    d|                                dz
  z  |	                    d          fz             }|
                    d|                              d          }n| j         dk    rt          |                     |          }|                     |          }|                     |          }|                     |          }|S )ak  
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        r  Nr{   firstr   rs  r#   r   r  .r   r   )r{   r  )r  rs  r]   	full_likerY   r   	unsqueezer   r   r   gatherre  r  r  r  rc  r  )r   r   r  r  s       r1   r   zElectraSequenceSummary.forwardt  s    &&"111b5)FF'))"111a4(FF&(("''A'..FF+-- !O!#rr111*-!'+a/*  		 &//33==bAA	%,,Uimmoo6I-JmN`N`acNdNdMf-fgg	"))"i88@@DDFF&((%%##F++f%%((""6**r3   r  )r   r   r   r   r$   r   r]   r   r   r   r   r   r   s   @r1   r  r  =  s         2H} H H H H H H< Y]) )".);CEDT;U)		) ) ) ) ) ) ) )r3   r  z
    ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of the
    pooled output) e.g. for GLUE tasks.
    c                   \    e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deej                 deej                 d	eej                 d
ee         dee         dee         de	e
ej                 ef         fd            Z xZS ) ElectraForSequenceClassificationc                     t                                          |           |j        | _        || _        t	          |          | _        t          |          | _        |                                  d S r  )	r   r   r  rb   r  rq  r  r<   r  r   s     r1   r   z)ElectraForSequenceClassification.__init__  sb        +#F++3F;; 	r3   Nr   r   r   rz   r   r   labelsr   rM  rN  r   c                    |
|
n| j         j        }
|                     ||||||||	|
	  	        }|d         }|                     |          }d}|Z| j         j        f| j        dk    rd| j         _        nN| j        dk    r7|j        t          j        k    s|j        t          j	        k    rd| j         _        nd| j         _        | j         j        dk    rWt                      }| j        dk    r1 ||                                |                                          }n |||          }n| j         j        dk    rGt                      } ||                    d| j                  |                    d                    }n*| j         j        dk    rt                      } |||          }|
s|f|dd         z   }||f|z   n|S t          |||j        |j        	          S )
a  
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        Nr   r   rz   r   r   r   rM  rN  r   r#   
regressionsingle_label_classificationmulti_label_classificationr{   r  rg  r   rS  )rb   r  rq  r<   problem_typer  r   r]   r   rV   r
   re  r	   r   r   r   r   rS  )r   r   r   r   rz   r   r   r  r   rM  rN  rf  sequence_outputrg  r  loss_fctr  s                    r1   r   z(ElectraForSequenceClassification.forward  s   ( &1%<kk$+B]&*ll))%'/!5# '3 
'
 
'
# 6a811{'/?a''/;DK,,_q((flej.H.HFL\a\eLeLe/LDK,,/KDK,{'<77"99?a''#8FNN$4$4fnn6F6FGGDD#8FF33DD)-JJJ+--xB @ @&++b//RR)-III,..x// 	FY!<QRR!@@F)-)9TGf$$vE'5C2=	
 
 
 	
r3   
NNNNNNNNNN)r   r   r   r   r    r   r]   r   r   r   r   r   r   r   r   s   @r1   r  r    sL             -11515/3,004)-,0/3&*D
 D
EL)D
 !.D
 !.	D

 u|,D
 EL)D
  -D
 &D
 $D>D
 'tnD
 d^D
 
uU\"$<<	=D
 D
 D
 ^D
 D
 D
 D
 D
r3   r  z
    Electra model with a binary classification head on top as used during pretraining for identifying generated tokens.

    It is recommended to load the discriminator checkpoint into that model.
    c                   \    e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deej                 deej                 d	eej                 d
ee         dee         dee         de	e
ej                 ef         fd            Z xZS )ElectraForPreTrainingc                     t                                          |           t          |          | _        t	          |          | _        |                                  d S r  )r   r   r  rq  r`  discriminator_predictionsr  r   s     r1   r   zElectraForPreTraining.__init__  sP       #F++)H)P)P&r3   Nr   r   r   rz   r   r   r  r   rM  rN  r   c                    |
|
n| j         j        }
|                     ||||||||	|
	  	        }|d         }|                     |          }d}|t	          j                    }|s|                    d|j        d                   dk    }|                    d|j        d                   |         }||         } |||                                          }n= ||                    d|j        d                   |                                          }|
s|f|dd         z   }||f|z   n|S t          |||j
        |j                  S )am  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the ELECTRA loss. Input should be a sequence of tokens (see `input_ids` docstring)
            Indices should be in `[0, 1]`:

            - 0 indicates the token is an original token,
            - 1 indicates the token was replaced.

        Examples:

        ```python
        >>> from transformers import ElectraForPreTraining, AutoTokenizer
        >>> import torch

        >>> discriminator = ElectraForPreTraining.from_pretrained("google/electra-base-discriminator")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")

        >>> sentence = "The quick brown fox jumps over the lazy dog"
        >>> fake_sentence = "The quick brown fox fake over the lazy dog"

        >>> fake_tokens = tokenizer.tokenize(fake_sentence, add_special_tokens=True)
        >>> fake_inputs = tokenizer.encode(fake_sentence, return_tensors="pt")
        >>> discriminator_outputs = discriminator(fake_inputs)
        >>> predictions = torch.round((torch.sign(discriminator_outputs[0]) + 1) / 2)

        >>> fake_tokens
        ['[CLS]', 'the', 'quick', 'brown', 'fox', 'fake', 'over', 'the', 'lazy', 'dog', '[SEP]']

        >>> predictions.squeeze().tolist()
        [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ```Nr  r   r{   r#   r  )rb   r  rq  r  r   r   r   rY   floatr  r   rS  )r   r   r   r   rz   r   r   r  r   rM  rN  rf  discriminator_sequence_outputrg  r  r  active_lossactive_logitsactive_labelsr  s                       r1   r   zElectraForPreTraining.forward	  s   Z &1%<kk$+B]&*ll))%'/!5# '3 
'
 
'
# )DA(F%//0MNN+--H),11"6S6YZ[6\]]abb &B0M0STU0V W WXc d &{ 3x}/B/B/D/DEExB0M0STU0V W WY_YeYeYgYghh 	FY!<QRR!@@F)-)9TGf$$vE*5C2=	
 
 
 	
r3   r  )r   r   r   r   r    r   r]   r   r   r   r   r  r   r   r   s   @r1   r  r    sL             -11515/3,004)-,0/3&*Q
 Q
EL)Q
 !.Q
 !.	Q

 u|,Q
 EL)Q
  -Q
 &Q
 $D>Q
 'tnQ
 d^Q
 
uU\"$??	@Q
 Q
 Q
 ^Q
 Q
 Q
 Q
 Q
r3   r  z
    Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is the only model of
    the two to have been trained for the masked language modeling task.
    c                   n    e Zd ZdgZ fdZd Zd Ze	 	 	 	 	 	 	 	 	 	 ddee	j
                 dee	j
                 dee	j
                 d	ee	j
                 d
ee	j
                 dee	j
                 dee	j
                 dee         dee         dee         deee	j
                 ef         fd            Z xZS )rO   generator_lm_head.weightc                 
   t                                          |           t          |          | _        t	          |          | _        t          j        |j        |j	                  | _
        |                                  d S r  )r   r   r  rq  rj  generator_predictionsr   r   r   r   generator_lm_headr  r   s     r1   r   zElectraForMaskedLM.__init__i  sj       #F++%@%H%H"!#6+@&BS!T!Tr3   c                     | j         S r  r  r  s    r1   get_output_embeddingsz(ElectraForMaskedLM.get_output_embeddingss      %%r3   c                     || _         d S r  r  )r   r   s     r1   set_output_embeddingsz(ElectraForMaskedLM.set_output_embeddingsv  s    !0r3   Nr   r   r   rz   r   r   r  r   rM  rN  r   c                    |
|
n| j         j        }
|                     ||||||||	|
	  	        }|d         }|                     |          }|                     |          }d}|Pt          j                    } ||                    d| j         j                  |                    d                    }|
s|f|dd         z   }||f|z   n|S t          |||j
        |j                  S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        Nr  r   r{   r#   r  )rb   r  rq  r  r  r   r	   r   r   r   r   rS  )r   r   r   r   rz   r   r   r  r   rM  rN  rn  generator_sequence_outputprediction_scoresr  r  r  s                    r1   r   zElectraForMaskedLM.forwardy  s2   ( &1%<kk$+B]"&,,))%'/!5# #/ 
#
 
#
 %<A$>! 667PQQ 223DEE*,,H8-222t{7MNNPVP[P[\^P_P_``D 	F'),CABB,GGF)-)9TGf$$vE$1?.9	
 
 
 	
r3   r  )r   r   r   _tied_weights_keysr   r  r  r    r   r]   r   r   r   r   r   r   r   r   s   @r1   rO   rO   ^  s`        55    & & &1 1 1  -11515/3,004)-,0/3&*4
 4
EL)4
 !.4
 !.	4

 u|,4
 EL)4
  -4
 &4
 $D>4
 'tn4
 d^4
 
uU\"N2	34
 4
 4
 ^4
 4
 4
 4
 4
r3   rO   z
    Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.
    c                   \    e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deej                 deej                 d	eej                 d
ee         dee         dee         de	e
ej                 ef         fd            Z xZS )ElectraForTokenClassificationc                 V   t                                          |           |j        | _        t          |          | _        |j        |j        n|j        }t          j        |          | _	        t          j
        |j        |j                  | _        |                                  d S r  )r   r   r  r  rq  r  r   r   r   r   r   r   r<   r  r  s      r1   r   z&ElectraForTokenClassification.__init__  s        +#F++)/)B)NF%%TZTn 	 z"455)F$68IJJr3   Nr   r   r   rz   r   r   r  r   rM  rN  r   c                    |
|
n| j         j        }
|                     ||||||||	|
	  	        }|d         }|                     |          }|                     |          }d}|Ft                      } ||                    d| j                  |                    d                    }|
s|f|dd         z   }||f|z   n|S t          |||j	        |j
                  S )z
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        Nr  r   r{   r#   r  )rb   r  rq  r   r<   r	   r   r  r   r   rS  )r   r   r   r   rz   r   r   r  r   rM  rN  rf  r  rg  r  r  r  s                    r1   r   z%ElectraForTokenClassification.forward  s   $ &1%<kk$+B]&*ll))%'/!5# '3 
'
 
'
# )DA(F%(,5R(S(S%!>??'))H8FKKDO<<fkk"ooNND 	FY!<QRR!@@F)-)9TGf$$vE$5C2=	
 
 
 	
r3   r  )r   r   r   r   r    r   r]   r   r   r   r   r   r   r   r   s   @r1   r  r    s8             -11515/3,004)-,0/3&*1
 1
EL)1
 !.1
 !.	1

 u|,1
 EL)1
  -1
 &1
 $D>1
 'tn1
 d^1
 
uU\"$99	:1
 1
 1
 ^1
 1
 1
 1
 1
r3   r  c                       e Zd ZU eed<   dZ fdZe	 	 	 	 	 	 	 	 	 	 	 ddee	j
                 dee	j
                 dee	j
                 dee	j
                 d	ee	j
                 d
ee	j
                 dee	j
                 dee	j
                 dee         dee         dee         deee	j
                 ef         fd            Z xZS )ElectraForQuestionAnsweringrb   rq  c                     t                                          |           |j        | _        t          |          | _        t          j        |j        |j                  | _        | 	                                 d S r  )
r   r   r  r  rq  r   r   r   
qa_outputsr  r   s     r1   r   z$ElectraForQuestionAnswering.__init__   se        +#F++)F$68IJJ 	r3   Nr   r   r   rz   r   r   start_positionsend_positionsr   rM  rN  r   c           
         ||n| j         j        }|                     |||||||	|
          }|d         }|                     |          }|                    dd          \  }}|                    d                                          }|                    d                                          }d }||t          |                                          dk    r|                    d          }t          |                                          dk    r|                    d          }|                    d          }|	                    d|          }|	                    d|          }t          |          } |||          } |||          }||z   dz  }|s||f|dd          z   }||f|z   n|S t          ||||j        |j                  S )	N)r   r   rz   r   r   r   rM  r   r#   r{   r   )ignore_indexr=   )r  start_logits
end_logitsr   rS  )rb   r  rq  r  rQ   re  r   rU   r   clampr	   r   r   rS  )r   r   r   r   rz   r   r   r  r  r   rM  rN  rf  r  rg  r  r  
total_lossignored_indexr  
start_lossend_lossr  s                          r1   r   z#ElectraForQuestionAnswering.forward
  s    &1%<kk$+B]&*ll))%'/!5 '3 	'
 	'
# 6a811#)<<r<#:#: j#++B//::<<''++6688

&=+D?''))**Q.."1"9"9""="==%%''((1,, - 5 5b 9 9(--a00M-33A}EEO)//=AAM']CCCH!,@@Jx
M::H$x/14J 	R ,ABB/0F 0:/EZMF**6Q+%!5C2=
 
 
 	
r3   )NNNNNNNNNNN)r   r   r   r$   r{  r}  r   r    r   r]   r   r   r   r   r   r   r   r   s   @r1   r  r    su        !      -11515/3,0042604,0/3&*@
 @
EL)@
 !.@
 !.	@

 u|,@
 EL)@
  -@
 "%,/@
  -@
 $D>@
 'tn@
 d^@
 
uU\"$@@	A@
 @
 @
 ^@
 @
 @
 @
 @
r3   r  c                   \    e Zd Z fdZe	 	 	 	 	 	 	 	 	 	 ddeej                 deej                 deej                 deej                 deej                 deej                 d	eej                 d
ee         dee         dee         de	e
ej                 ef         fd            Z xZS )ElectraForMultipleChoicec                     t                                          |           t          |          | _        t	          |          | _        t          j        |j        d          | _	        | 
                                 d S rb  )r   r   r  rq  r  sequence_summaryr   r   r   r<   r  r   s     r1   r   z!ElectraForMultipleChoice.__init__P  sh       #F++ 6v > >)F$6:: 	r3   Nr   r   r   rz   r   r   r  r   rM  rN  r   c                    |
|
n| j         j        }
||j        d         n|j        d         }|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|)|                    d|                    d                    nd}|=|                    d|                    d          |                    d                    nd}|                     ||||||||	|
	  	        }|d         }|                     |          }|                     |          }|                    d|          }d}|t                      } |||          }|
s|f|dd         z   }||f|z   n|S t          |||j
        |j                  S )a[  
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
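
        Example (an illustrative sketch; the checkpoint name is an assumption and the
        multiple-choice head is randomly initialized until fine-tuned):

        ```python
        >>> from transformers import AutoTokenizer, ElectraForMultipleChoice
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
        >>> model = ElectraForMultipleChoice.from_pretrained("google/electra-small-discriminator")

        >>> prompt = "In Italy, pizza served in formal settings is presented unsliced."
        >>> choice0 = "It is eaten with a fork and a knife."
        >>> choice1 = "It is eaten while held in the hand."

        >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
        >>> outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})  # batch size is 1

        >>> logits = outputs.logits  # shape (1, 2): one score per choice
        ```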
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

        input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
        attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
        token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
        position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
        inputs_embeds = (
            inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
            if inputs_embeds is not None
            else None
        )

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = discriminator_hidden_states[0]

        pooled_output = self.sequence_summary(sequence_output)
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        if not return_dict:
            output = (reshaped_logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )


@auto_docstring(
    custom_intro="""
    ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class ElectraForCausalLM(ElectraPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["generator_lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)

        if not config.is_decoder:
            logger.warning("If you want to use `ElectraForCausalLM` as a standalone, add `is_decoder=True.`")

        self.electra = ElectraModel(config)
        self.generator_predictions = ElectraGeneratorPredictions(config)
        self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)

        self.init_weights()

    def get_output_embeddings(self):
        return self.generator_lm_head

    def set_output_embeddings(self, new_embeddings):
        self.generator_lm_head = new_embeddings

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

        Example:

        ```python
        >>> from transformers import AutoTokenizer, ElectraForCausalLM, ElectraConfig
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google/electra-base-generator")
        >>> config = ElectraConfig.from_pretrained("google/electra-base-generator")
        >>> config.is_decoder = True
        >>> model = ElectraForCausalLM.from_pretrained("google/electra-base-generator", config=config)

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.logits
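
        >>> # an illustrative extension of the example above: passing labels
        >>> # (here, the inputs themselves) additionally returns the causal LM loss
        >>> outputs = model(**inputs, labels=inputs["input_ids"])
        >>> loss = outputs.loss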
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if labels is not None:
            use_cache = False

        outputs = self.electra(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.generator_lm_head(self.generator_predictions(sequence_output))

        lm_loss = None
        if labels is not None:
            lm_loss = self.loss_function(
                prediction_scores,
                labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        if not return_dict:
            output = (prediction_scores,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )


__all__ = [
    "ElectraForCausalLM",
    "ElectraForMaskedLM",
    "ElectraForMultipleChoice",
    "ElectraForPreTraining",
    "ElectraForQuestionAnswering",
    "ElectraForSequenceClassification",
    "ElectraForTokenClassification",
    "ElectraModel",
    "ElectraPreTrainedModel",
    "load_tf_weights_in_electra",
]
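

# A minimal usage sketch for ElectraForQuestionAnswering, assuming the public
# `google/electra-small-discriminator` checkpoint (the span-prediction head is
# randomly initialized until fine-tuned on extractive QA data such as SQuAD):
#
#     from transformers import AutoTokenizer, ElectraForQuestionAnswering
#     import torch
#
#     tokenizer = AutoTokenizer.from_pretrained("google/electra-small-discriminator")
#     model = ElectraForQuestionAnswering.from_pretrained("google/electra-small-discriminator")
#     question = "Who proposed ELECTRA?"
#     context = "ELECTRA was proposed by researchers at Google Brain and Stanford University."
#     inputs = tokenizer(question, context, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     start = outputs.start_logits.argmax(-1).item()
#     end = outputs.end_logits.argmax(-1).item()
#     answer = tokenizer.decode(inputs.input_ids[0, start : end + 1])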