
    `i                     
   U d dl Z d dlZd dlZd dlZd dlmZmZ d dlmZm	Z	m
Z
 d dlmZ d dlmZmZmZmZmZmZ d dlZd dlmZ d dlmZ d dlmZ d dlmZmZmZm Z m!Z!m"Z" d dl#m$Z$ d d	l%m&Z&m'Z'm(Z)m*Z*m+Z+m,Z,m-Z-m.Z. d d
l/m0Z0m1Z1 d dl2m3Z3 d dl4m5Z5 d dl6m7Z8 d dl9m:Z: g dZ;dZ<dZ=dZ>dZ?e@eA         ZBee3eejC        eDeEeAf         ZFeeFeGeF         eHeF         eIeAdf         f         ZJeIeAeJf         ZKeGeK         ZLeIeAeeKeLf         f         ZM e@            ZNe@e         eOd<   e jP        d             ZQe	 G d d                      ZRe	 G d deR                      ZS	 	 	 dKdejT        deAdeAd eUd!eUd"eBfd#ZV G d$ d%          ZWdLd&ZXddd'dejT        d(eHejY        jZ        d)f         d*eUd+ee@ejT                          d,eeR         d"eSfd-Z[d.eIeAeJf         d/eMd0eSd"dfd1Z\d2eejT        ejY        jZ        f         d3eAd"efd4Z]d5eIeAef         d0eSd"eIeAef         fd6Z^ ej_                    dejT        d0eSd"eIeAeJf         fd7            Z` ej_                    dejT        d5eIeAeJf         d0eSd"e5fd8            Zad9ejY        jZ        d"dfd:Zbd5eMd"eIeAeJf         fd;Zcd9ejY        jZ        d5eIeAeJf         d0eSd"eMfd<Zd ej_                    dejT        d=eHejY        jZ        d)f         d0eSd"eMfd>            ZedejT        d9ejY        jZ        d/eMd0eSd"eMf
d?Zf ej_                    dejT        d=eHejY        jZ        d)f         d5eMd0eSd"df
d@            Zgddd'dejT        d+ee@ejT                          d,eeR         d"eIeAeJf         fdAZhddd'dejT        d=eejY        jZ        eejY        jZ                 f         d+ee@ejT                          d,eeR         d"eMf
dBZiddd'dejT        d=eejY        jZ        eejY        jZ                 f         d+ee@ejT                          d,eeR         d"eHeIeAeJf         eMf         f
dCZjdejT        d5eeIejT        eIeAeJf         f         eIeAeJf         f         d"eIeAeJf         fdDZkddEdejT        d.eIeAeJf         d,eeR         d"e5fdFZlddEdejT        d=eejY        jZ        eejY        jZ                 f         d/eMd,eeR         d"df
dGZmddEdejT        d=eejY        jZ        eejY        jZ                 f         d.eIeAeJf         d/eMd,eeR         d"e5fdHZneddEdejT        d,eeR         d"dfdI            ZoeddEdejT        d=eHejY        jZ        d)f         d,eeR         d"dfdJ            ZpdS )M    N)	GeneratorIterable)asdict	dataclassfield)chain)AnyCallablecastno_type_checkOptionalUnion)ShardedTensor)_broadcast_state_dict_distribute_state_dict_flatten_state_dict_gather_state_dict_offload_state_dict_to_cpu_unflatten_state_dict)_CHECKPOINT_PREFIX)FullOptimStateDictConfigFullStateDictConfigFullyShardedDataParallelOptimStateDictConfigShardedOptimStateDictConfigShardedStateDictConfigStateDictConfigStateDictType)._get_module_fsdp_state_if_fully_sharded_moduleFSDP_WRAPPED_MODULE)DTensor)_IncompatibleKeys)DistributedDataParallel)tree_map_only)FQNS_TPrimitiveType	ValueTypeDictValueTypeListDictValueTypeOptimizerStateTypeStateDictOptionsget_model_state_dictget_optimizer_state_dictget_state_dictset_model_state_dictset_optimizer_state_dictset_state_dict_flat_paramparam_groupsparamsstater'   _patched_state_dictc               #      K   t          j                    } t          j                     	 d V  | rt          j                     d S d S # | rt          j                     w w xY wN)gc	isenableddisableenable)
is_enableds    {/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/torch/distributed/checkpoint/state_dict.py_gc_contextr?   Q   sh      JJLLL 	IKKKKK	 	: 	IKKKK	s   A Ac                       e Zd ZU dZdZeed<   dZeed<   dZeed<   dZ	eed<   dZ
eed<   dZeed	<   dZeed
<   dZeed<   dS )r+   ap  
    This dataclass specifies how get_state_dict/set_state_dict will work.

    - ``full_state_dict``: if this is set to True, all the tensors in the
      returned state_dict will be gathered. No ShardedTensor and DTensor
      will be in the returned state_dict.

    - ``cpu_offload``: offload all the tensors to cpu. To prevent CPU OOM, if
      ``full_state_dict`` is also true, then only the rank0 will get the
      state_dict and all other ranks will get empty state_dict.

    - ``ignore_frozen_params``: if the value is True, the returned state_dict
      won't contain any frozen parameters -- the ``requires_grad`` is False.
      The default value is False.

    - ``keep_submodule_prefixes`` (deprecated): when ``submodules`` is not None, this option
      indicates whether to keep the submodule prefixes from the state_dict keys.
      or example, if the submodule is ``module.pretrain`` and the full FQN of
      the parameter is ``pretrain.layer1.weight`` of the param. When this option
      is True, the parameter's key in the returned state_dict will be
      ``pretrain.layer1.weight``. If the options is False, the key will be
      ``layer1.weight``.
      Note that if ``keep_submodule_prefixes`` is False, there may be conflicted
      FQNs, hence there should be only one submodule in ``submodules``.

    - ``strict``: the ``strict`` option when ``set_state_dict`` calls
      model.load_state_dict().

    - ``broadcast_from_rank0``: when the option is True, rank0 should receive a
       full state_dict and will broadcast the tensors in the state_dict/
       optim_state_dict one by one to other ranks. Other ranks will receive
       the tensors and shard according to the local shards in the model and
       optimizer. ``full_state_dict`` must be set to True when using this option.
       This option currently only supports DTensor, not the legacy ShardedTensor.
    Ffull_state_dictcpu_offloadignore_frozen_paramsTkeep_submodule_prefixesstrictbroadcast_from_rank0flatten_optimizer_state_dict_fqn_modifiersdsd_fqn_modifiersN)__name__
__module____qualname____doc__rA   bool__annotations__rB   rC   rD   rE   rF   rG   rI   str     r>   r+   r+   \   s         " "H "OT!!!K!&$&&&$(T(((FD!&$&&&). $...-s-----rR   r+   c                   v   e Zd ZU  ee          Zeeeej	        f         ee
ej	        f         f         ed<    ee          Zeeeej	        f         ee
ej	        f         f         ed<    ee          Zee         ed<   dZeed<   dZeed<   ej        Zeed<    ee          Zeej                 ed	<   d
S )_StateDictInfo)default_factoryfqn_param_mappingshared_params_mappingsubmodule_prefixesThandle_modelhandle_optimfsdp_contextfsdp_modulesN)rJ   rK   rL   r   dictrV   r   rP   torchTensorr%   rO   rW   setrX   rY   rN   rZ   
contextlibnullcontextr[   r
   listr\   nnModulerQ   rR   r>   rT   rT      s,        
 	d### tc5< fel"#	% $ $ $ 	d### 4c5< fel"#	% $ $ $ $)5#=#=#=C===L$L$'3L(333$)E$$?$?$?L$ry/?????rR   rT   rH   TmodelnamerI   skip_ddp_prefixskip_compiler_prefixreturnc                 4   |                     t          d          }d|vr|hS |                    d          }g }| }t          |          D ]\  }}	t	          |t
                    r'|	dk    sJ |j        }|s|                    |	           Bt	          |t                    r|t          |          dz
  k     rZ||dz            t          k    rFd                    |          t          |t                    }
r dfd|
j        D             c S t          |t                    }|	t          k    r%|                    |	           t          ||	          }t	          |t          j        j        j                  r(|	dk    sJ |j        }|s|                    |	           Zt)          ||          rM t          ||                                          |	          x}r t)          ||          rt          ||          }|                    |	           |	t,          j        j        j        k    r'|t          |          dz
  k    rt3          d          t          ||	          }d                    |                               t          d          hS )a  
    This API is used to convert the name of a parameter to the FQNs. For FSDP
    without `use_orig_params`, the name of FlatParameter can be mapped to
    multiple original parameters. As a result, the return type of this function
    is `set[str]`.

    Args:
        module (nn.Module): the root model.
        name (str): the name
        skip_ddp_prefix (bool): whether to skip DDP's `module` prefix

    Returns:
        The canonical FQNs based on the model traversal.
     .module   c                     h | ]} | 	S rQ   rQ   ).0fqnprefixs     r>   	<setcomp>z_get_fqns.<locals>.<setcomp>   s$    EEES6(3((EEErR   	_orig_modz-Expect `_extra_state` to be the last obj name)replacer   split	enumerate
isinstanceDDPrn   appendFSDPlen_FLAT_PARAMjoingetattr_fqnsr    r^   _dynamo
eval_frameOptimizedModuleru   hasattrgetrd   modules_EXTRA_STATE_KEY_SUFFIXRuntimeError)rf   rg   rI   rh   ri   	obj_namesfqn_obj_namescurr_objicurr_obj_name
flat_paramremoved_fqnrs   s               @r>   	_get_fqnsr      s   . <<*B//D
$v

3IMH%i00 $< $<=h$$ #	< H,,,,H" 4$$]333$'' 	<3y>>A%%%)AE*:k*I*I-00$X{;;
 * &\\\FEEEEJ4DEEEEEEx)<==H 333$$]333"8];;%-":"JKK 	< K////)H' 4$$]333 x!233 B"F'(4E"F"F"H"H"L"L!# # ; B x55 B#*8[#A#A  ///
 1 IIII***&'VWWW + #8];;HH]##++,>CCDDrR   c                       e Zd ZdS )_EXTRA_STATEN)rJ   rK   rL   rQ   rR   r>   r   r      s        DrR   r   c              #      K   t                      dt          j        dt          dt          ffd | d          E d {V  d S )Nrn   curr_fqnrj   c              3     K                        |            |r| dnd}|                                 D ]i\  }}|v r
t          |           r7| t          |                                                       v r|d d         }n| | } ||          E d {V  jt          |                     d          |                     d                    D ]\  }}|| j        v r| | }||fV  t          | j	        dt          j        j                  t          j        j        k    r.| t          j        j        j         }|t!                      fV  d S d S )Nrm   rl   F)recurseget_extra_state)addnamed_childrenr   r   valuesr   named_buffersnamed_parameters_non_persistent_buffers_set	__class__rd   re   r   r   rn   r   r   )	rn   r   rg   	submodulenew_fqnobjrI   r   visited_moduless	         r>   r   z+_iterate_valid_model_state.<locals>.recurse   s     F###%-5h>>>>2%4466 	3 	3OD)O++  122.>GF,=>>@@GGIIII #3B3-%-t--wy'2222222222   //1H1HQV1H1W1W
 
 	 	ID# v999!)4))G3, F$&79RSSy() ) "N2:#4#LNNG<>>))))))	) )rR   rl   )r`   rd   re   rP   r   )rf   rI   r   r   s    `@@r>   _iterate_valid_model_stater      s~      &)eeO *	  *S  *Y  *  *  *  *  *  *  *  *D wub!!!!!!!!!!!rR   )
submodulesoptionsoptims.
optim_onlyr   r   c                d   |rt          j        dt                     |r|st          d          |pt	                      }i }i }t          |           D ]\  }}t          |t                    rt          | |          }	|	                    |d          }
|
Et          t          t                   ||                                       |	           ||         ||<   n|	                                ||<   |	D ]}
t          |t                    s|||
<   t          |                                          D ]'\  }}|D ]}
t          t"          j        |          ||
<    (t                      }|rzt          |          }|                                 D ]V\  }}||vr
t          | |          }	t)          |	          dk    s
J d            |                    d |	D                        W|j        r|j        st/          d          t1          j        |           }|r|j        rJt5          |j        |j                  }t9          |j        |j        p|j                  }t:          j        }n6t?          |j        	          }tA          |j        	          }t:          j!        }tD          j#        d
             }tI          j%        || |||          }ntD          j&        }tO          di tQ          |          ||||t          t          tR          j*                 |          | t)          |          dk    dS )zW
    Verify the model and options passed by the user and generates _StateDictInfo.
    zGetting submodules only model/optim state_dict is deprecated and will be removed in 2.5. This feature can be achieved by manually filtering out the state_dict returned from get_state_dict.z;Optimizers are not passed in but optim_only is set to True.Nro   z)Submodule FQN should only have 1 instancec              3       K   | ]	}| d V  
dS )rm   NrQ   )rq   rr   s     r>   	<genexpr>z"_verify_options.<locals>.<genexpr>K  s(      %@%@Ciii%@%@%@%@%@%@rR   z?full_state_dict must be True when broadcast_from_rank0 is True.)offload_to_cpu
rank0_only)r   c              3     K   t          j                    5  t          j        ddt                     t	          j        | |||          5  d V  d d d            n# 1 swxY w Y   d d d            d S # 1 swxY w Y   d S )NignorezFSDP.state_dict_type)messagecategoryrn   state_dict_typestate_dict_configoptim_state_dict_config)warningscatch_warningsfilterwarningsFutureWarningr|   r   r   s       r>   $fsdp_state_dict_type_without_warningz=_verify_options.<locals>.fsdp_state_dict_type_without_warningi  s#      (** 
 
'&<}    )!$3&7,C	     EEE              	
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
s5   5A4AA4A 	 A4#A 	$A44A8;A8r   r   )rV   rW   rX   r[   r\   rY   rZ   rQ   )+r   warnr   r   r+   r   ry   r   r   r   r   r`   rP   updatecopyrc   itemsr^   r_   named_modulesr}   rF   rA   
ValueErrorr|   r\   r   rB   r   r   FULL_STATE_DICTr   r   SHARDED_STATE_DICTra   contextmanager	functoolspartialrb   rT   r   rd   re   )rf   r   r   r   r   rV   rW   rg   paramfqnsrr   param_fqns_rX   rn   r\   r   r   r   r   r[   s                        r>   _verify_optionsr     s     
I 		
 	
 	
  
& 
I
 
 	
 +)++G 	 
 	  2%88 / /ee\** 	%%##E400?S,U344;;DAAA+<U+C!%(( (,yy{{e$ 	/ 	/Ce\22 /).!#&	/ 399;;<< D D 	D 	DC)-elF)C)C!#&&	D $'55 A__
!//11 	A 	ALD&Z''UD))Dt99>>>#N>>>%%%@%@4%@%@%@@@@@# 
G,C 
M
 
 	
 $U++L  /." 	? 3&2w?R! ! ! '?&2#/O73O' ' '# ,;OO 6&2! ! ! 'B&2' ' '# ,>O		"	 	 
#	"	$ !(0+/$;
 
 
 "- 	 	
//	+3-!$ry/<88#^&kkAo	 	 	 	 	rR   model_state_dictoptim_state_dictinfoc                    |j         D ]}t          |          }|
J d            |j        rP| sN|j        sG|j        s@|j        r|j        s2|j        r+|j        s$t          dt          j                    d          |j        r)|s'|j        r|j        s|j        st          d|           |                                 D ]%}t          |v rt          | dt           d          &d S )Nz)Expected a fsdp_state with a fsdp module.z}The option indicates that model state_dict is required to save or load, but model state_dict is empty.rank = dist.get_rank()=rm   zgThe option indicates that model state_dict is required to save, or load but optim state_dict is empty. z
 contains z6. This can happen if the model is not the root module.)r\   r   rY   rX   rC   rB   rA   rE   rF   r   distget_rankrZ   keysr~   )r   r   r   rn   
fsdp_statekeys         r>   _verify_state_dictr     s   
 # S SCFKK
%%'R%%%%
 	
 
 '
 )	

 !

 '+&:
 K
 )
 *moo* * *
 
 	
  	 	%	*.*>	 .	
 M:JM M  
  $$&&  # * *+ * * *    rR   r   apic                     t          | |          }|t          v r)t          j        t          | j        |          |           }|S )N)self)r   r6   r   r   r   )r   r   calls      r>   _state_dict_fnr     sC    3D""" !<!<3GGGKrR   
state_dictc                     |j         r@|j        rt          j                                        sdnd}t          | |j        |          S |j        rt          |           S | S )NrQ   )r   )rB   
ranks_only)rA   rB   r^   distributedis_initializedr   r   )r   r   r   s      r>   _maybe_full_or_cpu_state_dictr     s       $,1,=,L,L,N,NBB 	
 "D$4
 
 
 	
 
	 )*555rR   c                    |j         si S |                                5   t          | d                      }d d d            n# 1 swxY w Y   t          |                                          D ]}t          | |          }t          |          dk    sJ ||f            t          t          |                    }||k    rDdt          fd} |||          st          d| d|           |                    |          ||<   |j        rpi }|                                D ]W}|j        D ]M}|                    |          s|j        r||         ||<   +|t          |          d          }	||         ||	<   NX|}|j        rL|                                 D ]7\  }}
|
j        rt          | |          }|D ]}|                    |           8t%          ||          S )Nr   ro   rj   c                 T   t          |          t          |           k    rdS |                    d          }|                     d          }d}t          |          D ]I\  }}|||         k    r1|dz  }|t          |          k    r|t          |          dz
  k    c S B|dv rG dS dS )NFrm   r   ro   )rn   ru   T)r}   rw   rx   )r   rr   	fqn_split	key_splitfqn_idxkey_idxkey_names          r>   verifyz%_get_model_state_dict.<locals>.verify  s    s88s3xx'' 5IIcNN	IIcNN	)29)=)= % %%GX9W#5551"c)nn44#*c)nnq.@#@@@@ 5!%<<< $uutrR   zAn unexpected key, z, exists. FQN is )rY   r[   r   rc   r   r   r}   nextiterrN   r   poprX   
startswithrD   rC   r   requires_gradr   )rf   r   r   r   r   rr   r   new_state_dictrs   r   r   s              r>   _get_model_state_dictr     s     						 ; ;8^E<88::
; ; ; ; ; ; ; ; ; ; ; ; ; ; ; JOO%%&& 2 2$$4yyA~~~T{~~~4::#::D    " 6#s## V"#T#T#Ts#T#TUUU(nnS11JsO $/1??$$ 	> 	>C1 > >~~f-- / >*4S/N3''!#f++--0G.8oN7++> $
  $0022 	$ 	$JC" UC((D $ $s####$ )T:::s   AA
Ac           	         |j         r	|s|j        st          i i           S i }t          | |j                  D ]\  }}t          | ||j                  }t          | ||j        dd          }t          ||          D ]f\  }}	|j        rt          j                    dk    r>||	k    r8|	                    |d           }
|
|j
        rt          d| d          n|
||	<   |||	<   gd}|j        s|j        rt                      }|                                D ]K\  }}t          j        |          r2|                                dk    r|                    |j                   Lt          j        d          |v r)|                    t          j        d                     d}t+          |          dk    r2|                    t          j                                                   n"t+          |          dk    rt1          d	          |j        r1t3          |||	                                |j
        |j        
           n+|j        r$t7          |||	                                           |                    |           |                                5  t=          t           t?          | d          ||j
        |                    cd d d            S # 1 swxY w Y   d S )NF)rh   ri   r   zMissing key: rm   metaTro   zMultiple devices found)devicerE   rB   r   load_state_dict)r   rE   assign) rY   rF   r"   r   rI   r   zipr   r   r   rE   r   rA   r`   r   r^   	is_tensordimr   r   remover}   distributed_c10d_get_pg_default_devicer   r   rB   r   r   r[   r   r   )rf   r   r   local_state_dictr   valuer   fqns_with_prefixrr   fqn_with_prefix
load_valuer   devicess                r>   _load_model_state_dictr    sw     )Z )8Q ) R(((08NOO 6 6
UT%;<<$"!!&
 
 
 %(.>$?$? 
	6 
	6 C-=15A1E1E(('^^C66
%{ C*+A3+A+A+ABBBC 3=J/05_--
	6 F  ,D$8 ,%%*0022 	* 	*JCu%% *%))++//EL))) <7**NN5<//000Fw<<1KK-DDFFGGGG\\A5666$ 		W! {{}}{ ,     ! 	W":/?VVVV*+++					 
 
4N5"344%dk&  
 

 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
s   4KKKoptimc                 B   | j         rdS | j        D ]}|t                   D ]}|j          dS | j        D ]2}|t                   D ]"}|j        rt          j        |          |_        #3g }| j        D ]Z}d|v rT|                    |d                    t          |d         t
          j	                  rt          j
        d          nd|d<   [|                     d           | j        D ]}d|v r|                    d          |d<   |                     d           dS )zH
    Initialize optim states by calling the step() with zero grads.
    Nlrg        )closurer   T)set_to_none)r5   r3   _PARAMSgradr   r^   
zeros_liker{   ry   r_   tensorstepr   	zero_grad)r  param_groupr   lrss       r>   _init_optim_stater  ]  sx    {  )   ) 	 	Ez% &	 ) 5 5 ) 	5 	5E" 5"-e44
	5 C)  ;JJ{4())) k$/>>S!!! 
 
JJtJ ) + +; #

K	OOO%%%%%rR   c           
      ,   d }i }t          t          | t                                                             D ]O\  }}t          t          |                                          D ]"\  }} ||           ||t           d| d| <   #Pt          t          | t
                             D ]k}|                    t                    }t          t          t                   |          D ].}|                                D ]\  }}||t
           d| d| <   /l|S )aI  
    This API flattens the optimizer state_dict to support optimizer resharding for
    MPMD, e.g., pipeline parallelism.

    Without the API, the original optimizer state_dict looks like:
    {
        "state": {
            "layer1.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
            "layer2.weight": {
                "step": 10, "exp_avg": SomeTensor, "exp_avg_sq": SomeTensor
            },
        },
        "param_group": [
            {
                "lr": 0.0,
                "betas": (0.9, 0.95), ...,
                "params": ["layer1.weight", "layer2.weight"]
            }
        ]
    }

    With this API, the optimizer state_dict looks like:
    {
        "state.layer1.weight.step": 10,
        "state.layer2.weight.step": 10,
        "state.layer1.weight.exp_avg": SomeTensor,
        "state.layer2.weight.exp_avg": SomeTensor,
        "state.layer1.weight.exp_avg_sq": SomeTensor,
        "state.layer2.weight.exp_avg_sq": SomeTensor,
        "param_group.layer1.weight.lr" : 0.1,
        "param_group.layer2.weight.lr" : 0.1,
        "param_group.layer1.weight.betas" : (0.9, 0.95),
        "param_group.layer2.weight.betas" : (0.9, 0.95),
    }

    Note that if any of the value is a container, like the betas in the example,
    this API won't flattent it.
    c                     t          | t          j        t          t          f          s t          dt          |            d          d S )NzUFlattening optimizer state_dict only supports tensor, int, float states now. Type is rm   )ry   r^   r_   intfloatNotImplementedErrortype)vs    r>   _raise_if_type_not_supportedz?_flatten_optim_state_dict.<locals>._raise_if_type_not_supported  sV    !elC788 	%&77& & &  	 	rR   rm   )
r   r(   _STATEr   r)   _PGr   r  rc   rP   )	r   r  retrr   r5   kr  r  r   s	            r>   _flatten_optim_state_dictr    sJ   T   !#C=*V*<==CCEE + +
U..4466 	+ 	+DAq((+++)*C6%%C%%!%%&&	+ -z#?? , ,w''S	4(( 	, 	,C#))++ , ,1*+s&&S&&1&&'',	, JrR   c                    i }g }t           |t          |i}| j        D ]}|                    t          g i           |t                   D ]}|j        |         D ]}||j        v r;d}	|                                D ]#}
|
t          k    rt           d| d|
 }||v rd}	 nd}	|	sK|d         t                   }t          |t                    sJ |                    |           |j
        si ||<   | j        |                                         D ]3}|t            d| d|          t          t          ||                   |<   4t          t          t                   |d         t                             d         }|                                D ]s}
|
t          k    r|t           d| d|
          }|
|d         vr||d         |
<   9|d         |
         |k    r(t          d| d|
 d| d|d         |
          d	          t|S )	z
    This API unflattens the state_dict generated by _flatten_optim_state_dict().
    See the docstring of _flatten_optim_state_dict() for more detail.
    Frm   Tr   r   zaAll the parameters in the same parameter group should have the same saved param_group value. But z is z while other(s) is )r  r  r3   r{   r  rV   rW   r   ry   rc   r   r5   r   r(   rP   r   )r  r   r   r5   pg_state
return_osdr  r   rr   	in_paramsr  flatten_keyr4   
state_namefirst_param_fqnr   s                   r>   _unflatten_optim_state_dictr%    s    E"$H&,eS(%CJ) - -"&&& ) 	 	E-e4   $444 %I(--//  <<$),&8&8s&8&8Q&8&8&*44(,I $I  !"g.!&$/////c"""* c
"'+e"4"9"9";";  JBL!66C66*66CDc
33J??3< tCy(2,w*?@@C!!## 	 	AG||#====!==>E$$"'Q"aE))"G=LG GOPG GG G4<RLOG G G   *	 rR   
optimizersc                    |j         si S t          i t          g i}|D ]}t          |            t	          |d                      }|j        r|                                5  t          j        | ||          }d d d            n# 1 swxY w Y   |svt          |t                   
                                          D ]H}d|v rB|t                                       |          |t                   |                    dd          <   I|t                   D ]#}d |t                   D             }||t          <   $nqt          t          j        d |j        D                                 }t#          t%          |t'          t)          |                                        }	i |                                 D ]]\  }
}t-          | |
          }t)          |          dk    sJ t/          t1          |                    }||	vrK|	|         }||<   ||<   ^t          |t                   
                                          D ]8}
|
         }|t                                       |
          |t                   |<   9|t                   D ]#}fd|t                   D             |t          <   $|st3          t4          |t                                                 |t                              t3          t8          |t                                                 |t                              |j        r"t3          t>          tA          |                    }tC          ||          S )	Nr   ru   
_orig_mod.rl   c                 :    g | ]}|                     d d          S )r(  rl   rv   rq   r  s     r>   
<listcomp>z)_get_optim_state_dict.<locals>.<listcomp>!  s&    JJJ!!))L"55JJJrR   c              3   0   K   | ]}|t                    V  d S r8   )r  )rq   gs     r>   r   z(_get_optim_state_dict.<locals>.<genexpr>$  s&      -U-UQaj-U-U-U-U-U-UrR   ro   c                      g | ]
}|         S rQ   rQ   )rq   pidfqn_pid_mappings     r>   r,  z)_get_optim_state_dict.<locals>.<listcomp>6  s    !Q!Q!Q3/#"6!Q!Q!QrR   )"rZ   r  r  r  r   r\   r[   r|   r   rc   r   r   rv   r  r   from_iterabler3   r]   r   ranger}   r   r   r   r   r   r(   r   r)   extendrG   r*   r  r   )rf   r&  r   r   r  osdr  r.  r4   param_pid_mappingr   r   r   rr   r0  groupr1  s                   @r>   _get_optim_state_dictr8    s     	,2BR+@ ,H ,H%   1nUL1133 #	R""$$ ? ?+E5#>>? ? ? ? ? ? ? ? ? ? ? ? ? ? ?  #f+**,,-- R R!##?B6{q?Q?QCK		, ; ;<X $ $JJqzJJJ#'

$ %--U-U%BT-U-U-UUUVVF $Ss6{{1C1C%D%D E E O#4466 + +
U ,,4yyA~~~~4::&& 111'.'*$'*$$CK,,..// 8 8%c*#&v;??3#7#7FC  S R R!Q!Q!Q!Q%.!Q!Q!Qg 	],V455<<S[III 0 566==c#hGGGG( 
 9:J K K
 
 ))94@@@s   "BB		B		c           
      D   i }g }t           |t          |i}i }t          d t          t          |t                                                              D                       r|S |j        D ]}|                    t          g i           |t                   D ]f}	|j	        |	         D ]T}
|
|j
        v rWd}t          t          |t                             D ]3}|
t          t          t                   |t                             v rd} n4nd}|sh|d         t                   }t          |t                    sJ |                    |
           |	j        r)t          t          |t                              |
         ||
<   t          t          |t                             D ]\}|
t          t          t                   |t                             v r-t!          |t                             dz
  |t#          |          <   ]Vht!          |t                             dk    rg }t          t          |t                             D ]S}t!          t          t          t                   |t                                       dk    r|                    |           Tt!          |          dk    rt%          d          t!          |t                             t!          |j                  k    rt%          d          t!          |t                             dz
  |t#          |          <   t          t          |t                             D ]]}|                    t#          |          d          }|dk    r,|                                D ]\  }}|t          k    r|||         |<   ^|S )	a  
    Extract the corresponding optim state_dict from ``optim_state_dict`` for
    ``optim`` and return the result optim state_dict.

    Args:
        model (nn.Module): the root model.
        optim (torch.optim.Optimizer): the optimizer.
        optim_state_dict (Dict[str, ValueType]): the superset optim state_dict that
            contains the optim state_dict of ``optim``.
        info (_StateDictInfo): state dict information.

    Returns:
        The optim state_dict of ``optim``.
    c              3   @   K   | ]}t          |t                    V  d S r8   )ry   r  r+  s     r>   r   z*_split_optim_state_dict.<locals>.<genexpr>`  s=         
1c     rR   FTr   ro   r   zThere are param groups that have zero parameters. In such a case, DSD only support exactly one param group with zero parameters.But the loaded state_dict has zero or more than one param groups that have zero parameters.z`When there is a parameter group that has zero parameters, multiple optimizers are not supported.)r  r  allr   r(   r   r3   r{   r  rV   rW   r)   rc   rP   ry   r   r}   idr   r   r   )rf   r  r   r   r5   r  r   
pg_mappingr  r   rr   r!  loaded_param_groupr4   r  pg_idxr   r   s                     r>   _split_optim_state_dictr@  F  s   * E"$H&,eS(%CJ!#J
  $(8H8P$Q$Q$V$V$X$X        ) /J /J"&&& ) 	V 	VE-e4 V V$444 %I.2)+;C+@/ / " "* $tCy2DW2M"N"NNN(,I!E O !%I  !"g.!&$/////c"""& T!%m5Ef5M!N!Ns!SE#J*.%'7'<+ + V V& d49.@.IJJJJ=@C=Q=QTU=U
2&8#9#9:	V'V2 {7#$$))C&*+<>Ns>S&T&T 3 3"tDI'9''BCCDDIIJJ12223xx1}} 1   #C())S1C-D-DDD =   25Z_1E1E1IJr,--.-/?/DEE 	* 	*;44R<<%++-- 	* 	*JCg~~$)HVS!!		* rR   c           	         |j         sd S |D ]M}t          |           |rSt          |v rt          | |||          }n9t	          |t          t          t          t          f         |          |          }ni }|j	        r| 
                                D ];\  }}t          | |          }t          | |d          }	||	k    r/t          |          dk    sJ |                                |	                                |t                   D ]M}
t          t          t          t          f         |
          }fd|t                    D             }||t           <   Nt          t"          |t                             }t%          |                                          D ]2}|v r,|                    |          ||                              <   3=|                                5  t-          j        | ||          }d d d            n# 1 swxY w Y   n-|j        r%d|_        t3          | |f|          }d|_        d fd}t5          t6          j        ||          }J t;          |          \  }}t;          |          \  }}|j        rt?          ||           ntA          ||           |                                D ]"}||vr||v sJ ||         ||<   ||         ||<   #tC          ||          }|t                   D ]:}t           |vr/g t          t          t          t          f         |          t           <   ; tE          |d          |	           Od S )
NF)ri   ro   c                 <    g | ]}|                               S rQ   r*  )rq   r   rr   fqn_with_compilers     r>   r,  z*_load_optim_state_dict.<locals>.<listcomp>  s5       @CC):;;  rR   Tc                     |                                  dk    r$| j        n| j        k    rt          d          | S )Nr   zDevice mismatch)r   r   r   )tr   s    r>   _devicez'_load_optim_state_dict.<locals>._device  sC    5577Q;;~!"18++():;;;rR   r   r   )r   )#rZ   r  r  r@  r%  r   r]   rP   r'   r\   r   r   r}   r   r  r	   r  r(   rc   r   rv   r[   r|   optim_state_dict_to_loadrA   r8  r$   r^   r_   r   rF   r   r   r   r   )rf   r&  r   r   r  r   original_fqn_r   fqns_with_compilerr.  valr4   	osd_stater  r   rF  flatten_osdosd_mappingflatten_local_osdlocal_osd_mapping	optim_keypgr   rr   rC  s                          @@@r>   _load_optim_state_dictrS    sT      TN TN%    
	"###:5*d$ $   $?4S)^ 4jAA4$ $    " B	A $)#9#9#;#; X Xa 55%.<e& & &" ---4yyA~~~~hhjj$6$:$:$<$<!)#. * *AtCH~q11C    GJ7|  F $*CLL 0@0HII	inn..// X XAaxxGP}}UVGWGW	!))C1B"C"CDX ""$$  #'#@5"2$ $                ! %	A#(D 4UUHdKK#'D F     elG5EFFA%%%':;K'L'L$K3FGW3X3X00( V%k3DVTTTTT&{4EfUUUU
 )--// J J	$555$33333>y3I%i03>y3I%i04!#4    's+ A A"$$>@Dc9n-r227;
 	1u/00<LMMMMMiTN TNs   %HH	H	c                    t                      5  t          | dd||          }t          | |          }t          |i |           |cddd           S # 1 swxY w Y   dS )aH  
    Return the model state_dict of ``model``.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``model``.

    :rtype: typing.Dict[str, ValueType]
    rQ   Fr   r   r   N)r?   r   r   r   )rf   r   r   r   r   s        r>   r,   r,     s    0 
 
  
 !
 
 
 1==+R666
  
  
  
  
  
  
  
  
  
  
  
  
  
  
  
  
  
 s   7AAAc                $   t                      5  t          |t          j        j                  r|fnt          |          }t          | |d||          }t          | ||          }t          i ||           |cddd           S # 1 swxY w Y   dS )a  
    Return the combined state_dict for optimizers.

    See ``get_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        The state_dict for ``optimizers``.

    :rtype: OptimizerStateType
    TrU  N)	r?   ry   r^   r  	Optimizertupler   r8  r   )rf   r&  r   r   r   r   s         r>   r-   r-   *  s    6 
     *ek&;<<#ZMMz"" 	
 !
 
 
 1
DII2/666                                   s   A)BB	B	c                H   t                      5  t          |t          j        j                  r|fnt          |          }t          | |d||          }t          | |          }t          | ||          }t          |||           ||fcddd           S # 1 swxY w Y   dS )a  
    Return the model state_dict and optimizers state_dict.

    ``get_state_dict`` can process any module that is parallelized by PyTorch
    FSDP/fully_shard, DDP/replicate, tensor_parallel/parallelize_module, and any
    combination of these parallelisms. The main functions of ``get_state_dict``
    are: 1.) returning a model and optimizer state_dict that can be resharded
    with a different number of trainers and/or different parallelisms.
    2.) hiding the parallelism-specific state_dict APIs. Users don't have to call
    these APIs.
    3.) sanity checking the result state_dict.

    The keys of the result state dictionary are the canonical FQNs (Fully
    Qualified Names).  A canonical FQN refers to the FQN based on a parameter's
    position in an nn.Module hierarchy. More specifically, a canonical FQN to a
    parameter is the FQN returned by ``module.named_parameters()`` or
    ``module.named_buffers()`` when the module is not distributed by any
    parallelisms. Since the optimizer internally uses parameter IDs to represent
    a parameter, there will be a conversion from the parameter IDs to the
    canonical FQNs when calling this API.

    ``get_state_dict`` can also process a module that is not parallelized. In
    such a case, ``get_state_dict`` only performs one function -- converting the
    optimizer parameter IDs to the canonical FQNs.

    Example:
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        >>> from torch.nn.parallel import DistributedDataParallel as DDP
        >>> from torch.distributed.checkpoint.state_dict import get_state_dict

        >>> fsdp_model = FSDP(copy.deepcopy(model))
        >>> fsdp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        >>> ddp_model = DDP(copy.deepcopy(model))
        >>> ddp_optim = torch.optim.Adam(model.parameters(), lr=1e-3)


        >>> ddp_state_dict, ddp_optim_state_dict = get_state_dict(ddp_model, ddp_optim)
        >>> fsdp_state_dict, fsdp_optim_state_dict = get_state_dict(
        ...     fsdp_model, fsdp_optim
        ... )

        >>> # if we simply call ddp_model.state_dict() and fsdp_model.state_dict(),
        >>> # the asserts will fail.
        >>> assert ddp_state_dict == fsdp_state_dict
        >>> assert ddp_optim_state == fsdp_optim_state_dict


    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[None, Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        submodules (deprecated): Optional[set[nn.Module]]: only return the model parameters
            that belong to the submodules.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be returned. See
            `StateDictOptions` for the details.

    Returns:
        ``Tuple`` that contain model state_dict and optimizer state_dict.

    :rtype: typing.Tuple[typing.Dict[str, ValueType], OptimizerStateType]
    FrU  N)
r?   ry   r^   r  rW  rX  r   r   r8  r   )rf   r&  r   r   r   r   r   s          r>   r.   r.   W  s   P 
 2 2 *ek&;<<#ZMMz"" 	
 !
 
 
 1==0
DII+-=tDDD!11!2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2s   A;BBBc                   	 |si S t          t          t          |                                                    t          j                  rt          j        dt                     t          t          t          j        t          t          t          f         f         |          }i }|                                D ]\  }}|                                 D ]\  }}||k    rt          | |          }t!          |          dk    s
J d            t          t          |                     d	|                    	fd|                                D                        |S t          t          t          t          f         |          S )NzPassing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``is deprecated and will be removed in 2.5. If you need this feature, please preprocessing the model_state_dict to achieve the same functionality.ro   z/FQNs for a submodule should only have 1 elementrm   c                 "    i | ]\  }}|z   |S rQ   rQ   )rq   subfqnr   rs   s      r>   
<dictcomp>z/_unflatten_model_state_dict.<locals>.<dictcomp>  s#    XXXVf_eXXXrR   )ry   r   r   r   rd   re   r   r   r   r   r]   rP   r'   r   r   r   r}   r   )
rf   r   cast_state_dictr   r   sub_state_dictrg   mr   rs   s
            @r>   _unflatten_model_state_dictra    s     	$tJOO--..//;; 6" 	
 	
 	
 tBItCN/C$CDjQQ/1)8)>)>)@)@ 
	 
	%I~ ..00 	 	a	>> --4yyA~~~'X~~~ d,,///%%XXXXAUAUAWAWXXX   	 Di(*555rR   )r   c                    t          | |          }t                      5  t          | dd|          }t          |i |           t	          | ||          cddd           S # 1 swxY w Y   dS )a=  Load the model state_dict.

    The counterpart of ``get_model_state_dict`` to set the state_dict to the
    model. See ``set_state_dict`` for the detail usage.

    Args:
        model (nn.Module): the nn.Module to the model.
        model_state_dict: (Dict[str, ValueType]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be append to the state_dict.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    :type model_state_dict: typing.Dict[str, ValueType]
    rQ   Fr   r   N)ra  r?   r   r   r  )rf   r   r   r   s       r>   r/   r/     s    : .I. . 
 E EubUGLLL+R666%e-=tDD	E E E E E E E E E E E E E E E E E Es   5A!!A%(A%c                "   t                      5  t          |t          j        j                  r|fnt          |          }t          | |d|          }t          i ||           t          | |||           ddd           dS # 1 swxY w Y   dS )a  Load the optimizers state_dict.

    The counterpart of ``get_optimizer_state_dict`` to set the state_dict to the
    optimizers. See ``set_state_dict`` for the detail usage.

    WARN: ``set_optimizer_state_dict`` can only be called before ``backward()`` or after
        ``step()`` is called on the optimizers. Otherwise, the optimizer states won't be
        initialized correctly.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        None

    :type optim_state_dict: typing.OptimizerStateType
    Trc  N)	r?   ry   r^   r  rW  rX  r   r   rS  )rf   r&  r   r   r   s        r>   r0   r0     s    > 
 	J 	J *ek&;<<#ZMMz"" 	
 ujT7SSS2/666uj2BDIII	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	J 	Js   A(BBBc                d   t          | |          }t                      5  t          |t          j        j                  r|fnt          |          }t          | || |          }t          |||           t          | |||           t          | ||          cddd           S # 1 swxY w Y   dS )a  Load the model state_dict and optimizers state_dict.

    The counterpart of ``get_state_dict`` to set the state_dict to the model and
    optimizers.  The given ``model_state_dict`` and ``optim_state_dict`` do not
    have to be returned by ``get_state_dict`` but must meet the following
    requirements: 1) all FQNs are canonical FQNs as defined in ``get_state_dict``,
    2) if a tensor is sharded, it must be either a ShardedTensor or DTensor,
    3) optimizer state_dict cannot contain the parameter IDs; the keys should be
    the canonical FQNs.

    WARN: ``set_state_dict`` can only be called before ``backward()`` or after ``step()``
        is called on the optimizers. Otherwise, the optimizer states won't be initialized
        correctly.

    Args:
        model (nn.Module): the nn.Module to the model.
        optimizers (Union[Optimizer, Iterable[Optimizer]]):
            The optimizers that are used to optimize ``model``.
        model_state_dict: (Union[Dict[nn.Module, Dict[str, ValueType]], Dict[str, ValueType]]):
           the model state_dict to load. If the key of the ``model_state_dict``
           is nn.Module, the key is a submodule of ``model`` and the value should
           be the state_dict of the submodule. When loading the state_dict,
           the prefix of the submodule will be append to the state_dict.
        optim_state_dict: OptimizerStateType:
            the optimizer state_dict to load.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys of the model state_dict.
            * **unexpected_keys** is a list of str containing the unexpected keys of the model state_dict.

    :type model_state_dict: typing.Dict[str, ValueType]
    :type optim_state_dict: typing.OptimizerStateType
    rc  N)ra  r?   ry   r^   r  rW  rX  r   r   rS  r  )rf   r&  r   r   r   r   s         r>   r1   r1   %  s-   \ .I. . 
 E E *ek&;<<#ZMMz"" 	
 :.>*>
 
 
 	+-=tDDDuj2BDIII%e-=tDDE E E E E E E E E E E E E E E E E Es   A9B%%B),B)c                F   t          j        t          | |          fd}|| _        t          j        t          | |          dt
          t          t          f         ffd}|| _        t          
                    |           t          
                    |           dS )a  Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``model`` to
    be a partial function to call ``get_state_dict`` and ``set_state_dict``.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import patch_model_state_dict

        model = fsdp(model)
        patch_model_state_dict(model)

    Args:
        model (nn.Module): the nn.Module to the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    )rf   r   c                                    S r8   rQ   _state_dict_calls   r>   state_dict_callz0_patch_model_state_dict.<locals>.state_dict_call      !!!rR   r   c                       |            d S )N)r   rQ   r   _load_state_dict_calls    r>   load_state_dict_callz5_patch_model_state_dict.<locals>.load_state_dict_call      z::::::rR   N)r   r   r,   r   r/   r]   rP   r	   r   r6   r   )rf   r   rj  ro  rn  ri  s       @@r>   _patch_model_state_dictrq  g  s    6 !(  " " " " " 'E%-  ;c3h ; ; ; ; ; ; 1EO,,,011111rR   c                   t          j        t          | ||          fd}t          j        t          | ||          dt          t
          t          f         ffd}t                              |           t                              |           t          |t          j        j                  r|fnt          |          }|D ]}||_        ||_        dS )a  Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers``.

    Patch the ``state_dict`` and ``load_state_dict`` attributes of ``optimizers`` to
    be a partial function to call ``get_state_dict`` and ``set_state_dict``.

    Note that if there are multiple optimizers, all of the optimizers will be patched.
    So users only need to call one of the state_dict() to get the full result.

    Example:
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from torch.distributed.checkpoint.state_dict import patch_model_state_dict

        model = fsdp(model)
        patch_model_state_dict(model)

    Args:
        model (nn.Module): the nn.Module to the model.
        options (StateDictOptions): the options to control how
            model state_dict and optimizer state_dict should be loaded. See
            `StateDictOptions` for the details.
    Returns:
        None
    )rf   r&  r   c                                    S r8   rQ   rh  s   r>   rj  z4_patch_optimizer_state_dict.<locals>.state_dict_call  rk  rR   r   c                       |            d S )N)r   rQ   rm  s    r>   ro  z9_patch_optimizer_state_dict.<locals>.load_state_dict_call  rp  rR   N)r   r   r-   r0   r]   rP   r	   r6   r   ry   r^   r  rW  rX  r   r   )rf   r&  r   rj  ro  r  rn  ri  s         @@r>   _patch_optimizer_state_dictru    s   > !( 	  " " " " " &- 	  ;c3h ; ; ; ; ; ; O,,,0111 j%+"788	: 
  5 5* 45 5rR   )rH   TT)rH   )qra   r   r9   r   collections.abcr   r   dataclassesr   r   r   	itertoolsr   typingr	   r
   r   r   r   r   r^   torch.distributedr   r   torch.nnrd   'torch.distributed._shard.sharded_tensorr   #torch.distributed._state_dict_utilsr   r   r   r   r   r   ;torch.distributed.algorithms._checkpoint.checkpoint_wrapperr   torch.distributed.fsdpr   r   r   r|   r   r   r   r   r   $torch.distributed.fsdp._common_utilsr   r    torch.distributed.tensorr!   torch.nn.modules.moduler"   torch.nn.parallelr#   rz   torch.utils._pytreer$   __all__r~   r  r  r  r`   rP   r%   r_   r  r  r&   rc   rX  r]   r'   r(   r)   r*   r6   rO   r   r?   r+   rT   re   rN   r   r   r   r  rW  r   r   r   r   no_gradr   r  r  r  r%  r8  r@  rS  r,   r-   r.   ra  r/   r0   r1   rq  ru  rQ   rR   r>   <module>r     s!            				  / / / / / / / / 0 0 0 0 0 0 0 0 0 0       F F F F F F F F F F F F F F F F                    A A A A A A                    	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	 	        - , , , , , 5 5 5 5 5 5 < < < < < < - - - - - -  " 
		Sg}elCKL4&m(<d3CS>TT	 S)^$' #u]4E%EFFG  &)SUU S] * * *    ,. ,. ,. ,. ,. ,. ,. ,.^ @ @ @ @ @% @ @ @& . !%DE DE9DE
DE DE 	DE
 DE DE DE DE DEN	 	 	 	 	 	 	 	%" %" %" %"Z ,0*.  9%+',- 
 RY( &'    D*3	>**(* * 
	* * * *Zbi)>>? c h    S#X&4	#s(^   $ <;9<;*<;	#y.<; <; <; <;~ A
9A
S)^$A
 A
 	A
 A
 A
 A
H'&U[2 '&t '& '& '& '&T=*< =c9nAU = = = =@<; <S)^$< < 	< < < <~ <A9<Aek+S01<A <A 	<A <A <A <A~[9[; [ )[ 	[
 [ [ [ [| ]N9]Nek+S01]N #]N 	]N
 
]N ]N ]N ]NF ,0*.	"  "  " 9"  RY("  &'	" 
 
#y."  "  "  " R ,0*.*  *  * 9* ek+Xek6K-LLM*  RY(	* 
 &'*  *  *  *  * b ,0*.X2 X2 X29X2ek+Xek6K-LLMX2 RY(	X2
 &'X2 4Y!334X2 X2 X2 X2v696d29d3	>&::;T#y.=QQR6 
#y.6 6 6 6J +/	$E $E $E9$E3	>*$E &'	$E
 $E $E $E $EX +/(J (J (J9(Jek+Xek6K-LLM(J )(J
 &'(J 
(J (J (J (Jb +/=E =E =E9=Eek+Xek6K-LLM=E 3	>*	=E
 )=E &'=E =E =E =E =ED  +/12 12 12912 &'12 
	12 12 12 12l 
 +/	;5 ;5 ;59;5 ek+S01;5 &'	;5
 
;5 ;5 ;5 ;5 ;5 ;5rR   