
    &`iq                     6   d dl Z d dlZd dlZd dlZd dlZd dlmZmZmZm	Z	m
Z
mZ d dlZd dlZd dlmZ d dlmZmZ d dlmZ d dlmZ d dlmZmZmZmZmZ d dlmZm Z  d d	l!m"Z"m#Z# d d
l$m%Z% d dl&m'Z' d dl(m)Z)m*Z* d dl+m,Z, d dl-m.Z.m/Z/  ej0        e1          Z2 e/d          dej3        fd            Z4 e/d          de	ej3                 fd            Z5 e/d          	 	 	 d2dej6        j7        dee8ej3        f         de
e9         de
ee9ef                  dej6        j7        f
d            Z: e/d          	 	 	 d3dej;        j<        j        de8de8de8dej;        j<        j        f
d            Z=d  Z>e.d4d"e8ddfd#            Z?e.d$ej@        j        dej@        j        fd%            ZAe.d&ejB        ddfd'            ZC e/d          d5d(eDddfd)            ZEe. G d* d+                      ZF G d, d-e'          ZG G d. d/e          ZH G d0 d1e          ZIdS )6    N)AnyCallableDictListOptionalUnion)Version)
GradScalerautocast)DistributedDataParallel)	Optimizer)
DataLoaderDistributedSamplerIterableDatasetRandomSamplerSequentialSampler)TagKeyrecord_extra_usage_tag)#get_torch_device_manager_by_context'get_torch_device_manager_by_device_type)session)Accelerator)get_acceleratorset_accelerator)_log_deprecation_warning)
Deprecated	PublicAPIstable)	stabilityreturnc                  v    ddl m}  t          t          j        d           |                                 d         S )a  Gets the correct torch device configured for this process.

    Returns the torch device for the current worker. If more than 1 GPU is
    requested per worker, returns the device with the minimal device index.

    .. note::

        If you requested multiple GPUs per worker, and want to get
        the full list of torch devices, please use
        :meth:`~ray.train.torch.get_devices`.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.

    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:0")

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_device() == torch.device("cuda:2")


        You can move a model to device by:

        .. testcode::
            :skipif: True

            model.to(ray.train.torch.get_device())

        Instead of manually checking the device type:

        .. testcode::
            :skipif: True

            model.to("cuda" if torch.cuda.is_available() else "cpu")
    r   torch_utils1)ray.air._internalr#   r   r   TRAIN_TORCH_GET_DEVICEget_devicesr"   s    t/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/ray/train/torch/train_loop_utils.py
get_devicer)   $   sA    ~ .-----68#>>>""$$Q''    betac                  j    ddl m}  t          t          j        d           |                                 S )a  Gets the correct torch device list configured for this process.

    Assumes that `CUDA_VISIBLE_DEVICES` is set and is a
    superset of the `ray.get_gpu_ids()`.


    Examples:

        Example: Launched 2 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] == "2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:0")]

        Example: Launched 4 workers on the current node, each with 1 GPU

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
            ray.get_gpu_ids() == [2]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:2")]

        Example: Launched 2 workers on the current node, each with 2 GPUs

        .. testcode::
            :skipif: True

            os.environ["CUDA_VISIBLE_DEVICES"] == "0,1,2,3"
            ray.get_gpu_ids() == [2,3]
            torch.cuda.is_available() == True
            get_devices() == [torch.device("cuda:2"), torch.device("cuda:3")]
    r   r"   r$   )r%   r#   r   r   TRAIN_TORCH_GET_DEVICESr'   r"   s    r(   r'   r'   i   s<    R .-----693???""$$$r*   Tddpmodelmove_to_deviceparallel_strategyparallel_strategy_kwargsc                 
   |dk    r9t          t          j                  t          d          k     rt          d          t	          t
          j        d           t          t                    	                    | |||          S )a  Prepares the model for distributed execution.

    This allows you to use the same exact code regardless of number of
    workers or the device type being used (CPU, GPU).

    Args:
        model (torch.nn.Module): A torch model to prepare.
        move_to_device: Either a boolean indiciating whether to move
            the model to the correct device or an actual device to
            move the model to. If set to False, the model needs
            to manually be moved to the correct device.
        parallel_strategy ("ddp", "fsdp", or None): Whether to wrap models
            in ``DistributedDataParallel``, ``FullyShardedDataParallel``,
            or neither.
        parallel_strategy_kwargs (Dict[str, Any]): Args to pass into
            ``DistributedDataParallel`` or ``FullyShardedDataParallel``
            initialization if ``parallel_strategy`` is set to "ddp"
            or "fsdp", respectively.
    fsdpz1.11.0zsFullyShardedDataParallel requires torch>=1.11.0. Run `pip install 'torch>=1.11.0'` to use FullyShardedDataParallel.r$   )r0   r1   r2   )
r	   torch__version__ImportErrorr   r   TRAIN_TORCH_PREPARE_MODELr   _TorchAcceleratorprepare_model)r/   r0   r1   r2   s       r(   r:   r:      s    4 F""wu/@'A'AGHDUDU'U'UQ
 
 	

 6;SAAA,--;;%+!9	 <   r*   data_loaderadd_dist_samplerauto_transferc                     t          t          j        d           t          t                                        | |||          S )a	  Prepares :class:`~torch.utils.data.DataLoader` for distributed execution.

    This allows you to use the same exact code regardless of number of
    workers or the device type being used (CPU, GPU).

    .. note::

        This method adds a `DistributedSampler` to the `DataLoader` if the
        number of training workers is greater than 1. If shuffling is
        enabled on the original `DataLoader`, then `shuffle=True` will also
        be passed into the `DistributedSampler` constructor. `shuffle=False`
        on the original `DataLoader` also means that shuffling is disabled
        on the sampler.

        With more than 1 worker, calling the `DistributedSampler.set_epoch` method
        at the beginning of each epoch before creating the DataLoader iterator
        is necessary to make shuffling work properly across multiple epochs.
        Otherwise, the same ordering will be always used.
        See: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler  # noqa: E501

    Example:

    .. testcode:
        :skipif: True

        import torch

        import ray.train.torch

        train_dataloader = torch.utils.data.DataLoader(
            ..., batch_size=..., shuffle=True
        )
        train_dataloader = ray.train.torch.prepare_data_loader(train_loader)

        for epoch in range(10):
            if ray.train.get_context().get_world_size() > 1:
                # Required for the distributed sampler to shuffle properly across epochs
                train_dataloader.sampler.set_epoch(epoch)

            for X, y in train_loader:
                # No need to move data to GPU, this is done by `prepare_data_loader`!
                # X, y = X.to("cuda"), y.to("cuda")
                ...

    Args:
        data_loader (torch.utils.data.DataLoader): The DataLoader to
            prepare.
        add_dist_sampler: Whether to add a DistributedSampler to
            the provided DataLoader.
        move_to_device: If set, automatically move the data
            returned by the data loader to the correct device.
        auto_transfer: If set and device is GPU, another CUDA stream
            is created to automatically copy data from host (CPU) memory
            to device (GPU) memory (the default CUDA stream still runs the
            training procedure). If device is CPU, it will be disabled
            regardless of the setting. This configuration will be ignored
            if ``move_to_device`` is False.
    r$   )r<   r0   r=   )r   r   TRAIN_TORCH_PREPARE_DATALOADERr   r9   prepare_data_loader)r;   r<   r0   r=   s       r(   r@   r@      sK    B 6@#FFF,--AA)%#	 B   r*   c                  0    ddl m}  t          |            d S )Nr   _TORCH_AMP_DEPRECATION_MESSAGE)#ray.train.v2.torch.train_loop_utilsrC   r   rB   s    r(   _log_amp_deprecation_warningrE     s)    RRRRRR;<<<<<r*   Fampc                     t                       	 t          t          |                      dS # t          $ r t          d          w xY w)aQ  [Deprecated] Enables training optimizations.

    Arguments:
        amp: If true, perform training with automatic mixed precision.
            Otherwise, use full precision.

    .. warning:: ``train.torch.accelerate`` cannot be called more than once, and it
       must be called before any other ``train.torch`` utility function.
    rF   zAn accelerator has already been set. Make sure `train.torch.accelerate()` is not called multiple times, and is called before any of the prepare methods.N)rE   r   r9   RuntimeErrorrH   s    r(   
acceleraterJ     sg     !"""
)c22233333 
 
 
1
 
 	

s	   / A		optimizerc                 l    t                       t          t                                        |           S )z[Deprecated] Wraps optimizer to support automatic mixed precision.

    Args:
        optimizer (torch.optim.Optimizer): The DataLoader to prepare.

    Returns:
        A wrapped optimizer.
    )rE   r   r9   prepare_optimizer)rK   s    r(   rM   rM   (  s-     !""",--??	JJJr*   tensorc                 p    t                       t          t                                        |            dS )z[Deprecated] Computes the gradient of the specified tensor w.r.t. graph leaves.

    Args:
        tensor (torch.Tensor): Tensor of which the derivative will be computed.
    N)rE   r   r9   backward)rN   s    r(   rP   rP   6  s3     !"""%&&//77777r*   seedc                 T    t          t                                        |            dS )at  Limits sources of nondeterministic behavior.

    This function:

        * Seeds PyTorch, Python, and NumPy.
        * Disables CUDA convolution benchmarking.
        * Configures PyTorch to use determinstic algorithms.
        * Seeds workers spawned for multi-process data loading.

    Args:
        seed: The number to seed libraries and data workers with.

    .. warning:: ``train.torch.enable_reproducibility()`` can't guarantee
        completely reproducible results across executions. To learn more, read
        the `PyTorch notes on randomness
        <https://pytorch.org/docs/stable/notes/randomness.html>`_.
    N)r   r9   enable_reproducibility)rQ   s    r(   rS   rS   A  s'    & %&&==dCCCCCr*   c                   0    e Zd ZdZdZddee         fdZdS )TorchWorkerProfilerzUtility class for running PyTorch Profiler on a Train worker.

    Args:
        trace_dir (Optional[str]): The directory to store traces on the
           worker node. If ``None``, this will use a default temporary dir.
    pytorch_profiler_worker_tracesN	trace_dirc                      t          d          )NzGThe `ray.train.torch.TorchWorkerProfiler` API is deprecated in Ray 2.0.)DeprecationWarning)selfrW   s     r(   __init__zTorchWorkerProfiler.__init__b  s     U
 
 	
r*   N)__name__
__module____qualname____doc__WORKER_TRACE_DIR_NAMEr   strr[    r*   r(   rU   rU   W  sJ          =
 
(3- 
 
 
 
 
 
r*   rU   c                   0   e Zd ZdZddefdZ	 	 	 ddej        j        d	ed
e	e
         de	ee
ef                  dej        j        f
dZ	 	 	 ddej        j        j        ded	ededej        j        j        f
dZdedefdZdej        ddfdZddeddfdZdS )r9   zA utility that implements methods to accelerate PyTorch training.

    Arguments:
        amp: If true, perform training with automatic mixed precision.
            Otherwise, use full precision.
    FrF   c                 v    || _         |rt                      nd | _        d | _        t	                      | _        d S r\   )amp_is_enabledr
   scaler_seedr   device_manager)rZ   rF   s     r(   r[   z_TorchAccelerator.__init__p  s9    !&)3jlllt
ACCr*   Tr.   Nr/   r0   r1   r2   r    c                 t   |pi }t          j                    }t          |t          j                  r|}n+t                      }t          |t                    r|d         }| j                                        r| j        	                    |           |rV|dk    rt                              d|            nt                              d|            |                    |          }d }| j        rc|j        |_         t#                      |j                  |_        t%          |d          r|j        |_        t+          j        ||          |_        t          j                    }|r|dk    r|dk    r4t0          }	| j                                        r|j        dk    r|g|d|}n5t          j                                        st7          d	          dd
lm}
 |
}	|dk    r$t                              d|	j         d           n#t                              d|	j         d            |	|fi |}|S )a  Prepares the model for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            model (torch.nn.Module): A torch model to prepare.
            move_to_device: Whether to move the model to the correct
                device. If set to False, the model needs to manually be moved
                to the correct device.
            parallel_strategy ("ddp", "fsdp", or None): Whether to wrap models
                in ``DistributedDataParallel``, ``FullyShardedDataParallel`` (
                Experimental), or neither.
            parallel_strategy_kwargs (Dict[str, Any]): Args to pass into
                ``DistributedDataParallel`` or ``FullyShardedDataParallel``
                initialization if ``parallel_strategy`` is set to "ddp"
                or "fsdp", respectively.
        r   zMoving model to device: c                     t          | d          r#|                                 }|d         |d<   |d= n| j                                        }|d= |d         |d<   |d= |S )N_original_get_state__getstate___unwrapped_forwardforward)hasattrrl   __dict__copyrZ   states     r(   model_get_statez8_TorchAccelerator.prepare_model.<locals>.model_get_state  s     t233 *0022(-.C(Dn%/00 **,,.)$%9:E)*+Lr*   rm      r.   cpu)
device_idsoutput_devicezhFSDP is only available with GPU-enabled training. Set `use_gpu=True` in your Trainer to train with GPUs.)FullyShardedDataParallelzWrapping provided model in .)r   get_local_rank
isinstancer5   devicer)   listri   is_available
set_deviceloggerinfodebugtorf   ro   rn   r   rp   rm   rl   types
MethodTypeget_world_sizer   typecudarI   torch.distributed.fsdprz   r]   )rZ   r/   r0   r1   r2   rankr~   ru   
world_sizeDataParallelrz   s              r(   r:   z_TorchAccelerator.prepare_modelv  s   2 $<#Ar %''nel33 	##FF\\F&$'' #++-- 	3**6222 	%qyy?v??@@@@@@@AAAHHV$$E	 	 	$  	J (-}E$&HJJu}55EM un-- ?,1,>) "'!1/5!I!IE+--
 	Da E))6&3355 &+:N:N'-h)/0 0 30, z..00 &    LKKKKK7qyyR,:ORRRSSSSS<;PSSSTTT LCC*BCCEr*   r;   r<   r=   c                 R    t          j                    }t          j                    |dk    rWt          |j        t
                    s=t          |d          rt          |j        t                    s|r fd} ||          }|rt                      }t          |||          }|S )a  Prepares DataLoader for distributed execution.

        This allows you to use the same exact code regardless of number of
        workers or the device type being used (CPU, GPU).

        Args:
            data_loader (torch.utils.data.DataLoader): The DataLoader to
                prepare.
            add_dist_sampler: Whether to add a DistributedSampler to
                the provided DataLoader.
            move_to_device: If set, automatically move the data
                returned by the data loader to the correct device.
            auto_transfer: (Experimental) If set and device is GPU, another CUDA stream
                is created to automatically copy data from host (CPU) memory
                to device (GPU) memory (the default CUDA stream still runs the
                training procedure). If device is CPU, it will be disabled
                regardless of the setting. This configuration will be ignored
                if ``move_to_device`` is False.
        rv   datasetc                 j   t          | j        t                     }dt          t          t
          gd f                  fd}| j        }| j        }j        8 ||          }t          j
                    }|                    j                   t          | j        t          t          f          }|s3dk    r-t                              d| j        j        j         d           | j        | j        d| j        | j        | j        | j        | j        ||t1          | j        |          d}t3          d	i |S )
Nworker_init_fnc                 "     dt           f fd}|S )N	worker_idc                     t          j                    dz  }t          j                            |           t          j        |           r |            d S d S )Nl        )r5   initial_seednprandomrQ   )r   worker_seedr   s     r(   wrapperzk_TorchAccelerator.prepare_data_loader.<locals>.with_sampler.<locals>.seeded_worker_init_fn.<locals>.wrapper"  sb    &+&8&:&:U&B	{333K000) 6*N9555556 6r*   )int)r   r   s   ` r(   seeded_worker_init_fnzZ_TorchAccelerator.prepare_data_loader.<locals>.with_sampler.<locals>.seeded_worker_init_fn  s/    63 6 6 6 6 6 6 #Nr*   r   zThe z will be overwritten with a DistributedSampler. You can disable this by setting `with_sampler` to False in `prepare_data_loader`.F)shuffle)r   
batch_sizer   num_workers
collate_fn
pin_memory	drop_lasttimeoutr   	generatorsamplerrc   )r}   r   r   r   r   r   r   r   rh   r5   	Generatormanual_seedr   r   warning	__class__r]   r   r   r   r   r   r   r   r   r   )	loaderr   r   r   r   using_default_samplerdata_loader_argsrZ   
world_ranks	          r(   with_samplerz;_TorchAccelerator.prepare_data_loader.<locals>.with_sampler  se    )9JKKK
#$,XseTk-B$C
# 
# 
# 
# CIBW7=7G	:)%:%:>%J%JN % 1 1I))$*555(2N%6$F) )% - qNNLv~7@ L L L    &~"("3$#)#5"("3"("3!'!1%~&4!*1&.'RRR$ $  "55$4555r*   )r   r   get_world_rankr}   r   r   rp   r   r   r)   _WrappedDataLoader)	rZ   r;   r<   r0   r=   r   r   r~   r   s	   `       @r(   r@   z%_TorchAccelerator.prepare_data_loader  s    6 +--
+--
 NN{24FGG  Y//  {2ODD	  ! 86 86 86 86 86 86t ',{33K 	Q\\F,[&-PPKr*   rK   c                 .    t          || j                  S )zWraps optimizer to support automatic mixed precision.

        Args:
            optimizer (torch.optim.Optimizer): The DataLoader to prepare.

        Returns:
            A wrapped optimizer.
        )rg   )_WrappedOptimizerrg   )rZ   rK   s     r(   rM   z#_TorchAccelerator.prepare_optimizerS  s     !4;????r*   rN   c                     | j         r.| j                            |                                           dS |                                 dS )zComputes the gradient of the specified tensor w.r.t. graph leaves.

        Args:
            tensor (torch.Tensor): Tensor of which the derivative will be computed.
        N)rf   rg   scalerP   )rZ   rN   s     r(   rP   z_TorchAccelerator.backward^  sM      	Kf%%..00000OOr*   r   rQ   c                    || _         t          j        |           t          j        |           t
          j                            |           t          j        d           dt          j        j        _	        dt          j        d<   dS )z,Limits sources of nondeterministic behavior.TFz:4096:8CUBLAS_WORKSPACE_CONFIGN)rh   r5   r   r   rQ   r   use_deterministic_algorithmsbackendscudnn	benchmarkosenviron)rZ   rQ   s     r(   rS   z(_TorchAccelerator.enable_reproducibilityi  sq    
$D
	t*4000).&
 1:
,---r*   FTr.   N)TTFr   )r]   r^   r_   r`   boolr[   r5   nnModuler   rb   r   r   r:   utilsdatar   r@   r   rM   TensorrP   r   rS   rc   r*   r(   r9   r9   h  s        D DD D D D D  $+0=Al lxl l $C=	l
 #+4S>":l 
l l l lb "&##m m[%0m m 	m
 m 
		$m m m m^	@9 	@ 	@ 	@ 	@ 	@	u| 	 	 	 	 	: :3 :t : : : : : :r*   r9   c                   P    e Zd Zdedej        defdZd Zd Z	d Z
d Zd	 Zd
 ZdS )r   base_dataloaderr~   r=   c                    | j                             t          |di                      || _        d | _        || _        t          |j                  | _        |j        dk    r!| j        	                                r|| _
        nd| _
        |j        dk    r!| j
        r| j                            |          nd | _        d | _        d S )Nrq   rw   F)rq   updategetattr_dataloaderdataloader_iterr~   r   r   ri   supports_stream_auto_transfercreate_stream_memcpy_stream
next_batch)rZ   r   r~   r=   s       r(   r[   z_WrappedDataLoader.__init__{  s     	W_j"EEFFF*#EfkRR ;%D$7$G$G$I$I"/D"'D {e##(;# --f555 	
 r*   c                 \    |d S  fd} j                              j                  5  t          |t          j        j                  r! fd|                                D             }nt          |t                    rt           fd|D                       }nwt          |t                    r fd|D             }nSt          |t          j                  r ||          }n-t                              dt          |           d           |}|cd d d            S # 1 swxY w Y   d S )Nc                     	 |                      j        j                  } n6# t          $ r) t                              d|  dj         d           Y nw xY w| S )N)non_blockingzItem z cannot be moved to device r{   )r   r~   r   AttributeErrorr   r   )irZ   s    r(   try_move_devicez;_WrappedDataLoader._move_to_device.<locals>.try_move_device  sw    WDD43FDGG! W W WUQUUt{UUUVVVVVWHs   !% 0AAc                 B    i | ]\  }}|                     |          S rc   _move_to_device).0kvrZ   s      r(   
<dictcomp>z6_WrappedDataLoader._move_to_device.<locals>.<dictcomp>  s-    !V!V!VA!T%9%9!%<%<!V!V!Vr*   c              3   B   K   | ]}                     |          V  d S r\   r   r   r   rZ   s     r(   	<genexpr>z5_WrappedDataLoader._move_to_device.<locals>.<genexpr>  s1      &M&M1t';';A'>'>&M&M&M&M&M&Mr*   c                 :    g | ]}                     |          S rc   r   r   s     r(   
<listcomp>z6_WrappedDataLoader._move_to_device.<locals>.<listcomp>  s'    !H!H!Ha$"6"6q"9"9!H!H!Hr*   z
Data type z' doesn't support being moved to device.)ri   get_stream_contextr   r}   collectionsabcMappingitemstupler   r5   r   r   r   r   )rZ   itemr   item_on_devices   `   r(   r   z"_WrappedDataLoader._move_to_device  s   <4	 	 	 	 	  33D4GHH 	" 	"$ 788 &!V!V!V!V!V!V!VD%(( 
&!&&M&M&M&M&M&M&M!M!MD$'' &!H!H!H!H4!H!H!HD%,// &!0!6!6TdTTT   "&!	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	" 	"s   C*D!!D%(D%c                     | j         d S | j                                        }|                    | j                    |D ](}	 |                    |           # t
          $ r Y %w xY wd S r\   )r   ri   get_current_streamwait_streamrecord_streamr   )rZ   r   curr_streamr   s       r(   _wait_for_batchz"_WrappedDataLoader._wait_for_batch  s    &F
 )<<>> 3444  	 	A,,,,!   	 	s   A
A%$A%c                 *    t          | j                  S r\   )lenr   rZ   s    r(   __len__z_WrappedDataLoader.__len__  s    4#$$$r*   c                 d    t          | j        d           }|                     |          | _        d S r\   )nextr   r   r   rZ   r   s     r(   _prefetch_next_batchz'_WrappedDataLoader._prefetch_next_batch  s-    $.55
..z::r*   c                 `    t          | j                  | _        |                                  | S r\   )iterr   r   r   r   s    r(   __iter__z_WrappedDataLoader.__iter__  s,    #D$455!!###r*   c                 x    | j         }|t          |                     |           |                                  |S r\   )r   StopIterationr   r   r   s     r(   __next__z_WrappedDataLoader.__next__  sA    _
Z(((!!###r*   N)r]   r^   r_   r   r5   r~   r   r[   r   r   r   r   r   r   rc   r*   r(   r   r   z  s        )38<PT   ." " "8  .% % %; ; ;  
    r*   r   c                       e Zd Zddedee         fdZed             Zej	        d             Zed             Z
e
j	        d             Z
ed	             Zej	        d
             Zd Zd Zd Zd ZddZdS )r   NrK   rg   c                 "    || _         || _        d S r\   )rK   rg   )rZ   rK   rg   s      r(   r[   z_WrappedOptimizer.__init__  s    "r*   c                     | j         j        S r\   rK   rt   r   s    r(   rt   z_WrappedOptimizer.state  s    ~##r*   c                     || j         _        d S r\   r  rs   s     r(   rt   z_WrappedOptimizer.state  s    $r*   c                     | j         j        S r\   rK   param_groupsr   s    r(   r  z_WrappedOptimizer.param_groups  s    ~**r*   c                     || j         _        d S r\   r  )rZ   r  s     r(   r  z_WrappedOptimizer.param_groups  s    &2###r*   c                     | j         j        S r\   rK   defaultsr   s    r(   r  z_WrappedOptimizer.defaults  s    ~&&r*   c                     || j         _        d S r\   r
  )rZ   r  s     r(   r  z_WrappedOptimizer.defaults  s    "*r*   c                 :    | j                             |           d S r\   )rK   add_param_group)rZ   param_groups     r(   r  z!_WrappedOptimizer.add_param_group  s    &&{33333r*   c                 :    | j                             |           d S r\   )rK   load_state_dict)rZ   
state_dicts     r(   r  z!_WrappedOptimizer.load_state_dict  s    &&z22222r*   c                 4    | j                                         S r\   )rK   r  r   s    r(   r  z_WrappedOptimizer.state_dict  s    ~((***r*   c                 8    | j                                          d S r\   )rK   	zero_gradr   s    r(   r  z_WrappedOptimizer.zero_grad   s      """""r*   c                     | j         ;| j                             | j        |           | j                                          d S | j                            |           d S r\   )rg   steprK   r   )rZ   closures     r(   r  z_WrappedOptimizer.step  sZ    ;"KT^W555K     N(((((r*   r\   )r]   r^   r_   r   r   r
   r[   propertyrt   setterr  r  r  r  r  r  r  rc   r*   r(   r   r     s:        ) Xj5I     $ $ X$ \% % \% + + X+ 3 3 3 ' ' X' _+ + _+4 4 43 3 3+ + +# # #) ) ) ) ) )r*   r   r   )TTTr   r   )Jr   loggingr   r   r   typingr   r   r   r   r   r   numpyr   r5   packaging.versionr	   torch.cuda.ampr
   r   torch.nn.parallelr   torch.optimr   torch.utils.datar   r   r   r   r   ray._common.usage.usage_libr   r    ray.air._internal.device_managerr   r   ray.train._internalr   ray.train._internal.acceleratorr   ray.train._internal.sessionr   r   ray.train.utilsr   ray.util.annotationsr   r   	getLoggerr]   r   r~   r)   r'   r   r   r   rb   r:   r   r   r@   rE   rJ   optimrM   r   rP   r   rS   rU   r9   r   r   rc   r*   r(   <module>r,     s$        				   = = = = = = = = = = = = = = = =      % % % % % % / / / / / / / / 5 5 5 5 5 5 ! ! ! ! ! !              G F F F F F F F        ( ' ' ' ' ' 7 7 7 7 7 7 H H H H H H H H 4 4 4 4 4 4 6 6 6 6 6 6 6 6		8	$	$ XA(EL A( A( A( A(H V+%T%,' +% +% +% +%\ X 15',9=	% %8?%$,-%  }% 'tCH~6	%
 X_% % % %P X "	F F!,FF F 	F
 [ F F F FR= = = 
 
D 
T 
 
 
 
* 
K!6 
K5;;P 
K 
K 
K 
K 8U\ 8d 8 8 8 8 XD D DT D D D D* 
 
 
 
 
 
 
 
 O: O: O: O: O: O: O: O:d] ] ] ] ] ] ] ]@.) .) .) .) .)	 .) .) .) .) .)r*   