
    `iPl                        d Z ddlmZ ddlmZmZ ddlmZmZm	Z	m
Z
 ddlZddlmc mZ ddlZddlZddlmZ ddlmZ ddZddZ G d de          Z G d de          ZdS )z
Base classes for cuDNN API wrappers.

This module provides abstract base classes that define common interfaces
for cuDNN API wrapper classes, including validation, compilation, and execution patterns.
    )annotations)ABCabstractmethod)AnyListTupleOptionalN)_convert_to_cutlass_data_typeaintbreturnc                    | |z   dz
  |z  S N    )r   r   s     b/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/cudnn/api_base.pyceil_divr      s    EAI!    nboolc                &    | dk    o| | dz
  z  dk    S )zCheck if n is a power of 2.r   r   r   )r   s    r   is_power_of_2r      s    q5'a1q5ka''r   c                     e Zd ZdZd ZedBd            ZedCdDd
            ZeddddEd            ZdFdZ	dGdZ
dHdZdIdZdIdZdJdZdJd ZdKdLd#Z	 dKdMd%Z	 dKdMd&Z	 dKdNd+Z	 	 	 	 dOdPd2Z	 	 dQdRd6ZdSd9ZdSd:ZdSd;ZdTdUd?Z	 dVdWdAZdS )XAPIBaseaM  Abstract base class for cuDNN API wrappers.

    This class defines the common interface that all API wrapper implementations
    should follow, including configuration validation, compilation, and execution.

    Provides common functionality:
    - Logging via self._logger
    - Support validation tracking via self._is_supported
    - Compiled kernel caching via self._compiled_kernel
    - Stream management helpers

    Subclasses should implement the abstract methods to provide
    API-specific validation logic and execution behavior.

    Example:
        >>> class MyKernelAPI(APIBase):
        ...     def __init__(self, sample_input, sample_output, config):
        ...         super().__init__()
        ...         self.sample_input = sample_input
        ...         self.sample_output = sample_output
        ...         self.config = config
        ...         self._kernel = MyKernel
        ...
        ...     def check_support(self) -> bool:
        ...         # Validate inputs and configuration
        ...         assert self.sample_input.dtype == torch.float32
        ...         self._is_supported = True
        ...         return True
        ...
        ...     def compile(self, current_stream=None):
        ...         current_stream = self._get_default_stream(current_stream)
        ...         self._ensure_support_checked()
        ...         # Create and compile kernel
        ...         kernel = self._kernel(self.config)
        ...         self._compiled_kernel = cute.compile(kernel, ...)
        ...
        ...     def execute(self, input_tensor, output_tensor,
        ...                current_stream=None, skip_compile=False):
        ...         current_stream = self._get_default_stream(current_stream)
        ...         if not skip_compile:
        ...             self._compiled_kernel(input_tensor, output_tensor, current_stream)
        ...         else:
        ...             # Direct execution without cached compilation
        ...             kernel = self._kernel(self.config)
        ...             kernel(input_tensor, output_tensor, current_stream)
    c                    d| _         d| _        d| _        d| _        t	          j        | j        j                  | _        dS )a  Initialize the API base.

        Sets up:
        - self._is_supported: Flag indicating if configuration is validated
        - self._kernel: Kernel instance
        - self._compiled_kernel: Cache for compiled kernel
        - self._logger: Logger instance for this class
        FN)	_is_supported_kernel_compiled_kernel_interpret_uint8_as_fp4x2logging	getLogger	__class____name___loggerselfs    r   __init__zAPIBase.__init__Q   s>     # $).&()@AAr   r   r   c                    dS )a  Check if the current configuration is supported by the kernel.

        This method should validate:
        - Input/output tensor shapes and strides
        - Data types compatibility
        - Hardware capabilities (compute capability, memory, etc.)
        - Configuration parameters (tile sizes, cluster shapes, etc.)

        Implementations should set self._is_supported = True if valid.

        :return: True if the configuration is supported
        :rtype: bool
        :raises AssertionError: If a configuration requirement is not met

        Example:
            >>> def check_support(self) -> bool:
            ...     self._logger.debug("Checking support")
            ...     assert self.input.dtype in {torch.float16, torch.float32}
            ...     assert self.input.shape[0] % 16 == 0, "Shape must be 16-aligned"
            ...     self._is_supported = True
            ...     return True
        Nr   r&   s    r   check_supportzAPIBase.check_support`   s	    0 	r   Ncurrent_streamOptional[cuda.CUstream]Nonec                    dS )aZ  Compile the kernel with the current configuration.

        This method should:
        1. Ensure support has been checked (use self._ensure_support_checked())
        2. Get default stream if needed (use self._get_default_stream())
        3. Create the underlying kernel implementation
        4. Compile the kernel using cute.compile()
        5. Cache the compiled kernel in self._compiled_kernel

        :param current_stream: CUDA stream for compilation (optional)
        :type current_stream: cuda.CUstream or None
        :raises AssertionError: If the configuration is not supported

        Example:
            >>> def compile(self, current_stream=None):
            ...     current_stream = self._get_default_stream(current_stream)
            ...     self._ensure_support_checked()
            ...
            ...     kernel = self._kernel(self.config)
            ...     self._compiled_kernel = cute.compile(
            ...         kernel,
            ...         self.sample_input,
            ...         self.sample_output,
            ...         current_stream
            ...     )
        Nr   )r'   r+   s     r   compilezAPIBase.compilez   s	    8 	r   F)r+   skip_compiler0   r   c                   dS )ak  Execute the kernel with the provided inputs.

        This method should handle two execution modes:
        1. With compiled kernel (skip_compile=False): Use self._compiled_kernel
        2. Without compiled kernel (skip_compile=True): Create and execute kernel directly (JIT)

        :param args: Positional arguments (typically input/output tensors)
        :param current_stream: CUDA stream for execution (optional)
        :type current_stream: cuda.CUstream or None
        :param skip_compile: If False, use cached compiled kernel;
                            If True, create and execute kernel directly
        :type skip_compile: bool
        :param kwargs: Additional keyword arguments for execution
        :return: Execution result (if any)
        :raises AssertionError: If compiled kernel is not available when skip_compile=False

        Example:
            >>> def execute(self, input_tensor, output_tensor,
            ...            current_stream=None, skip_compile=False):
            ...     current_stream = self._get_default_stream(current_stream)
            ...
            ...     if not skip_compile:
            ...         assert self._compiled_kernel is not None, "Kernel not compiled"
            ...         self._logger.debug("Executing with compiled kernel")
            ...         self._compiled_kernel(input_tensor, output_tensor, current_stream)
            ...     else:
            ...         self._logger.debug("Executing without compiled kernel (JIT)")
            ...         kernel = self._kernel(self.config)
            ...         kernel(input_tensor, output_tensor, current_stream)
        Nr   )r'   r+   r0   argskwargss        r   executezAPIBase.execute   s
    L 	r   c                      | j         |ddi|S )a  Convenience method to execute the kernel.

        This is a shorthand for calling execute() with skip_compile=True,
        which bypasses the cached compiled kernel and executes directly.
        This is useful for one-off executions or when you want to ensure
        fresh compilation.

        :param args: Positional arguments passed to execute()
        :param kwargs: Keyword arguments passed to execute()
        :return: Result from execute()

        Example:
            >>> api = MyKernelAPI(...)
            >>> api.check_support()
            >>> # Direct execution without pre-compilation
            >>> api(input_tensor, output_tensor)  # Equivalent to execute(..., skip_compile=True)
        r0   T)r4   )r'   r2   r3   s      r   __call__zAPIBase.__call__   s     $ t|T?????r   c                    | j         sE| j                            | j        j         d           |                                 sJ d            dS dS )a  Helper to ensure check_support() was called before compilation.

        If check_support() has not been called yet (self._is_supported is False),
        this method will automatically call it. This prevents compilation
        with invalid configurations.

        :raises AssertionError: If check_support() returns False or raises

        Example:
            >>> def compile(self, current_stream=None):
            ...     self._ensure_support_checked()  # Automatic validation
            ...     # ... rest of compilation
        z2: check_support not previously called, calling nowzUnsupported configurationN)r   r%   infor#   r$   r*   r&   s    r   _ensure_support_checkedzAPIBase._ensure_support_checked   sg     ! 	EL!8lllmmm%%''DD)DDD'	E 	EDDr   streamcuda.CUstreamc                    |E| j                             | j        j         d           t          j                                        S |S )a  Get default CUDA stream if none provided.

        This is a convenience helper to handle optional stream parameters.
        If a stream is provided, it is returned as-is. If None, the default
        CUDA stream is returned.

        :param stream: CUDA stream or None
        :type stream: cuda.CUstream or None
        :return: CUDA stream (either provided or default)
        :rtype: cuda.CUstream

        Example:
            >>> def compile(self, current_stream=None):
            ...     current_stream = self._get_default_stream(current_stream)
            ...     # Now current_stream is guaranteed to be a valid stream
        Nz/: No CUDA stream provided, using default stream)r%   debugr#   r$   cutlasscudadefault_stream)r'   r:   s     r   _get_default_streamzAPIBase._get_default_stream   sF    " >L$."9jjjkkk<..000r   tensorOptional[torch.Tensor]ndimr   namestrc                    |b|j         |k     rW| j                            d| d| d|j                    t	          ||j         z
            D ]}|                    d          }|S )a  Pad a tensor by unsqueezing at dim -1 until it reaches ndim rank.

        - If tensor is None, returns None.
        - Unsqueezes at dim -1 until tensor.ndim == ndim.
        - Logs final reshape for traceability.

        :param tensor: The tensor to pad (or None)
        :param ndim: Target rank (pad trailing dims until reached)
        :param name: Logical tensor name for logging
        :return: The padded tensor (or None)
        NzPadding  to zD from )rD   r%   r8   shaperange	unsqueezer'   rB   rD   rE   _s        r   _pad_tensor_to_ndimzAPIBase._pad_tensor_to_ndim   s    " V[4%7%7LNNN4NNNNOOO4&+-.. . .))"--r   c           	     8   ||j         |k    r| j                            d| d|j         d| d           t	          |j         |z
            D ]}|                    d          }|j         |k    r)| j                            d| d|j         d| d           |S )	a  Unpad a tensor by squeezing at dim -1 until it reaches ndim rank.

        - If tensor is None, returns None.
        - Squeezes at dim -1 until tensor.ndim == ndim.
        - Logs final reshape for traceability.

        :param tensor: The tensor to unpad (or None)
        :param ndim: Target rank (squeeze trailing dims until reached)
        :param name: Logical tensor name for logging
        :return: The unpadded tensor (or None)
        Nz
Unpadding z from rH   DrI   z resulted in shape z, expected )rD   r%   r8   rJ   rK   squeezecriticalrM   s        r   _unpad_tensor_to_ndimzAPIBase._unpad_tensor_to_ndim  s    " V[4%7%7LP4PPv|PPPPPQQQ6;-.. , ,++{d""%%&l4&l&lFL&l&lei&l&l&lmmmr   tensor_or_dtypetorch.Tensor | torch.dtypec                    |dS t          |t          j                  r|j        n|}|t          j        k    p| j        o|t          j        k    S )a  Check if tensor or dtype is an FP4x2 packed datatype.

        :param tensor_or_dtype: The torch tensor or dtype to check
        :type tensor_or_dtype: torch.Tensor | torch.dtype
        :return: True if tensor/dtype is an FP4x2 packed type
        :rtype: bool
        NF)
isinstancetorchTensordtypefloat4_e2m1fn_x2r    uint8r'   rU   r[   s      r   	_is_fp4x2zAPIBase._is_fp4x2,  sS     "5)3OU\)R)Rg%%Xg//mT5S5lX]afalXlmr   c                    |dS t          |t          j                  r|j        n|}|t          j        t          j        hv S )zCheck if tensor or dtype is an FP8 datatype.

        :param tensor_or_dtype: The torch tensor or dtype to check
        :type tensor_or_dtype: torch.Tensor | torch.dtype
        :return: True if tensor/dtype is an FP8 type
        :rtype: bool
        NF)rX   rY   rZ   r[   float8_e5m2float8_e4m3fnr^   s      r   _is_fp8zAPIBase._is_fp89  sD     "5)3OU\)R)Rg%%Xg*E,?@@@r    torch.Tensorc           	     R   t          d t          |                                          D             d          }|k| j                            d| d|j         d|                                 d           t          d| d|j         d|                                 d          |S )zReturn index of innermost contiguous dimension (stride == 1).

        :raises RuntimeError: If no dimension with stride 1 is found.
        c              3  ,   K   | ]\  }}|d k    |V  dS )r   Nr   .0iss      r   	<genexpr>z4APIBase._get_innermost_stride_dim.<locals>.<genexpr>K  s*      GG$!QQAGGr   Nztensor z has shape: z stride u=    – innermost contiguous (stride == 1) dimension not found. )next	enumeratestrider%   rS   rJ   RuntimeError)r'   rB   rE   idxs       r   _get_innermost_stride_dimz!APIBase._get_innermost_stride_dimF  s    
 GG)FMMOO"<"<GGGNN;L!! Q$  Q  QFL  Q  Q&--//  Q  Q  Q      `   `   `6<   `   `QWQ^Q^Q`Q`   `   `   `  a  a  a
r   Optional[Tuple[int, ...]]c                    |dS |                      |          rn|                     ||          t          fdt          |j                  D                       }| j                            d| d|j         d|            |S |j        S )aK  Get the logical shape of a tensor, handling FP4x2 packed datatypes.

        For FP4x2 datatypes, two values are packed per byte. The innermost
        contiguous dimension (with stride 1) contains packed values, so the
        logical shape for that dimension is 2x the physical shape.

        :param tensor: The tensor to get shape from (or None)
        :type tensor: torch.Tensor or None
        :param name: Logical tensor name for logging
        :type name: str
        :return: The logical shape tuple (or None if tensor is None)
        :rtype: Tuple[int, ...] or None
        NrE   c              3  8   K   | ]\  }}|k    r|d z  n|V  dS    Nr   )ri   rj   diminnermost_dim_indexs      r   rl   z(APIBase._tensor_shape.<locals>.<genexpr>j  s:      mm61cQ*=%=%=#''3mmmmmmr   FP4x2 tensor z: physical shape z -> logical shape )r_   rr   tuplern   rJ   r%   r=   )r'   rB   rE   rJ   rz   s       @r   _tensor_shapezAPIBase._tensor_shapeS  s    $ >4>>&!! 	 "&"@"@d"@"S"SmmmmU^_e_kUlUlmmmmmELmtmmflmmfkmmnnnL<r   c                n   |dS |                      |          r|                     ||          t          fdt          |                                          D                       }| j                            d| d|                                 d|            |S |                                S )a[  Get the logical stride of a tensor, handling FP4x2 packed datatypes.

        For FP4x2 datatypes, two values are packed per byte. The strides must
        be adjusted to reflect logical element spacing. All strides are
        multiplied by 2 since each physical element contains 2 logical elements.

        :param tensor: The tensor to get stride from (or None)
        :type tensor: torch.Tensor or None
        :param name: Logical tensor name for logging
        :type name: str
        :return: The logical stride tuple (or None if tensor is None)
        :rtype: Tuple[int, ...] or None
        Nru   c              3  8   K   | ]\  }}|k    r|d z  n|V  dS rw   r   )ri   rj   rk   rz   s      r   rl   z)APIBase._tensor_stride.<locals>.<genexpr>  s:      llAQ*=%=%=AEE1llllllr   r{   z: physical stride z -> logical stride )r_   rr   r|   rn   ro   r%   r=   )r'   rB   rE   stridesrz   s       @r   _tensor_stridezAPIBase._tensor_stridep  s    $ >4>>&!! 	#"&"@"@d"@"S"SllllQZ[a[h[h[j[jQkQklllllGLttttv}}ttkrttuuuN==??"r   tensor_or_shapetorch.Tensor | Tuple[int, ...]rJ   'Tuple[int, ...] | List[Tuple[int, ...]]c                z   |dS t          |t          j                  r|                     ||          n|}t          |t                    r||k    rt          | d| d|           nPt          |t                    r||vrt          | d| d|           nt          dt          |                     |S )aY  Check if the shape of a tensor matches the expected shape(s).

        :param tensor_or_shape: The tensor to get shape from or the shape to check
        :type tensor_or_shape: torch.Tensor | Tuple[int, ...]
        :param shape: expected shape or list of expected shapes
        :type shape: Tuple[int, ...] | List[Tuple[int, ...]]
        :param name: Logical tensor name for logging
        :type name: str
        :raises ValueError: If the shape of the tensor does not match the expected shape(s)
        :return: The logical shape of the tensor
        :rtype: Optional[Tuple[int, ...]]
        Nru   z! tensor shape mismatch: expected , got z( tensor shape mismatch: expected one of z*Expected shape to be a tuple or list, got )rX   rY   rZ   r}   r|   
ValueErrorlisttype)r'   r   rJ   rE   tensor_shapes        r   _check_tensor_shapezAPIBase._check_tensor_shape  s    $ "4ISTcejeqIrIr  Ht))/)EEE  yHeU## 	Yu$$ D!f!f5!f!fXd!f!fggg %t$$ 	Y5(( D!m!mRW!m!m_k!m!mnnn ) W$u++WWXXXr   tensor_or_stridero   1Optional[Tuple[int, ...] | List[Tuple[int, ...]]]stride_orderextra_error_msg1Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]]c                   |dS t          |t          j                  r|                     ||          n|}t	          d t          t          |          d           D                       }|t          |t                    r*||k    r#| d| d| }|r|d	| z  }t          |          nht          |t                    r(||vr#| d
| d| }|r|d	| z  }t          |          n+dt          |           }|r|d	| z  }t          |          |t          |t                    r*||k    r#| d| d| }|r|d	| z  }t          |          nht          |t                    r(||vr#| d| d| }|r|d	| z  }t          |          n+dt          |           }|r|d	| z  }t          |          ||fS )aq  Check if the stride of a tensor matches the expected stride(s) or stride order(s).

        :param tensor_or_stride: The tensor to get stride from or the stride to check
        :type tensor_or_stride: torch.Tensor | Tuple[int, ...]
        :param stride: The expected stride(s)
        :type stride: Tuple[int, ...] | List[Tuple[int, ...]]
        :param stride_order: The expected stride order(s)
        :type stride_order: Tuple[int, ...] | List[Tuple[int, ...]]
        :param name: Logical tensor name for logging
        :type name: str
        :param extra_error_msg: Extra error message to add to the error
        :type extra_error_msg: str
        :raises ValueError: If the stride of the tensor does not match the expected stride order
        :return: The stride and stride order of the tensor
        :rtype: Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]]
        N)NNru   c              3      K   | ]	\  }}|V  
d S Nr   rh   s      r   rl   z/APIBase._check_tensor_stride.<locals>.<genexpr>  &      #g#g$!QA#g#g#g#g#g#gr   c                    | d         S r   r   xs    r   <lambda>z.APIBase._check_tensor_stride.<locals>.<lambda>      abcdae r   keyz" tensor stride mismatch: expected r   : z) tensor stride mismatch: expected one of z+Expected stride to be a tuple or list, got z( tensor stride order mismatch: expected z/ tensor stride order mismatch: expected one of z1Expected stride order to be a tuple or list, got )
rX   rY   rZ   r   r|   sortedrn   r   r   r   )	r'   r   ro   r   rE   r   tensor_stridetensor_stride_order	error_msgs	            r   _check_tensor_stridezAPIBase._check_tensor_stride  s   0 #:LVWginiuLvLv  M++,<4+HHH  }M##g#g&=9Q9QWeWe2f2f2f#g#g#ggg&%(( , F**#' h h6 h hYf h hI& <!%;/%;%;;	$Y///	 +
 FD)) 
, ..#' o oRX o o`m o oI& <!%;/%;%;;	$Y///	 / Y$v,,XX	" 8!7o!7!77I +++#,.. ,&,66#' z zQ] z zex z zI& <!%;/%;%;;	$Y///	 7
 L$// 
,&l::#'  !B  !BXd  !B  !Bl  !B  !BI& <!%;/%;%;;	$Y///	 ; ePTUaPbPbdd	" 8!7o!7!77I +++111r   r[   torch.dtype | List[torch.dtype]Optional[torch.dtype]c                   |dS t          |t          j                  r|j        n|}t          |t          j                  r*||k    r#| d| d| }|r|d| z  }t	          |          n\t          |t
                    r(||vr#| d| d| }|r|d| z  }t	          |          nt	          dt          |                     |S )a6  Check if the dtype of a tensor or dtype matches the expected dtype(s).

        :param tensor_or_dtype: The tensor to get dtype from or the dtype to check
        :type tensor_or_dtype: torch.Tensor | torch.dtype
        :param dtype: The expected dtype(s)
        :type dtype: torch.dtype | List[torch.dtype]
        :param name: Logical tensor name for logging
        :type name: str
        :raises ValueError: If the dtype of the tensor does not match the expected dtype(s)
        :return: The dtype of the tensor
        :rtype: Optional[torch.dtype]
        Nz dtype mismatch: expected r   r   z! dtype mismatch: expected one of z0Expected dtype to be a torch.dtype or list, got )rX   rY   rZ   r[   r   r   r   )r'   rU   r[   rE   r   tensor_dtyper   s          r   _check_dtypezAPIBase._check_dtype  s*   & "40:?EL0Y0Yn,,_neU[)) 	_u$$#ZZuZZLZZ	" 8!7o!7!77I +++	 %
 t$$ 	_5((#aaeaaS_aa	" 8!7o!7!77I +++	 ) ]PTUZP[P[]]^^^r   	conditionr   c                (    |rt          |          dS )a  Raise a ValueError if the condition is true.

        :param condition: The condition to check
        :type condition: bool
        :param error_msg: The error message to raise
        :type error_msg: str
        :raises ValueError: If the condition is true
        N)r   r'   r   r   s      r   _value_error_ifzAPIBase._value_error_if  s$      	(Y'''	( 	(r   c                (    |rt          |          dS )a  Raise a NotImplementedError if the condition is true.

        :param condition: The condition to check
        :type condition: bool
        :param error_msg: The error message to raise
        :type error_msg: str
        :raises NotImplementedError: If the condition is true
        N)NotImplementedErrorr   s      r   _not_implemented_error_ifz!APIBase._not_implemented_error_if!  s$      	1%i000	1 	1r   c                (    |rt          |          dS )a  Raise a RuntimeError if the condition is true.

        :param condition: The condition to check
        :type condition: bool
        :param error_msg: The error message to raise
        :type error_msg: str
        :raises RuntimeError: If the condition is true
        N)rp   r   s      r   _runtime_error_ifzAPIBase._runtime_error_if-  s$      	*y)))	* 	*r      assumed_aligncute.Pointerc                    |dS t           j                            t          |j        | j                  |                                t           j        j        |          S )a:  Make a cute.Pointer for a tensor.

        :param tensor: The tensor to make a cute.Pointer for
        :type tensor: torch.Tensor
        :param assumed_align: The assumed alignment of the tensor
        :type assumed_align: int
        :return: A cute.Pointer for the tensor
        :rtype: cute.Pointer
        N)interpret_uint8_as_fp4x2r   )	cuteruntimemake_ptrr
   r[   r    data_ptrAddressSpacegmem)r'   rB   r   s      r   _make_cute_pointerzAPIBase._make_cute_pointer9  s\     >4|$$)&,QUQopppOO"'	 % 
 
 	
r   5Tuple[cute.Pointer, Tuple[int, ...], Tuple[int, ...]]c                
   |dS |                      ||          }|                     ||          }|                     ||          }t          d t	          t          |          d           D                       }|||fS )a  Make a cute.Pointer, shape, and order for a tensor.

        :param tensor: The tensor to make a cute.Pointer, shape, and order for
        :type tensor: torch.Tensor
        :param assumed_align: The assumed alignment of the tensor
        :type assumed_align: int
        :param name: Logical tensor name for logging
        :type name: str
        :return: A cute.Pointer, shape, and stride order for the tensor
        :rtype: Tuple[cute.Pointer, Tuple[int, ...], Tuple[int, ...]]
        N)NNNr   ru   c              3      K   | ]	\  }}|V  
d S r   r   rh   s      r   rl   z7APIBase._make_cute_tensor_descriptor.<locals>.<genexpr>_  r   r   c                    | d         S r   r   r   s    r   r   z6APIBase._make_cute_tensor_descriptor.<locals>.<lambda>_  r   r   r   )r   r}   r   r|   r   rn   )r'   rB   r   rE   
tensor_ptrr   r   r   s           r   _make_cute_tensor_descriptorz$APIBase._make_cute_tensor_descriptorL  s     >##,,V=,QQ
))&t)<<++F+>>##g#g&=9Q9QWeWe2f2f2f#g#g#ggg<)<<<r   )r   r   r   )r+   r,   r   r-   )r+   r,   r0   r   r   r   )r   r   )r   r-   )r:   r,   r   r;   )rB   rC   rD   r   rE   rF   r   rC   )rU   rV   r   r   )rd   )rB   re   rE   rF   r   r   )rB   rC   rE   rF   r   rs   )r   r   rJ   r   rE   rF   r   rs   )NNrd   rd   )r   r   ro   r   r   r   rE   rF   r   rF   r   r   )rd   rd   )
rU   rV   r[   r   rE   rF   r   rF   r   r   )r   r   r   rF   r   r-   )r   )rB   re   r   r   r   r   )r   rd   )rB   re   r   r   rE   rF   r   r   )r$   
__module____qualname____doc__r(   r   r*   r/   r4   r6   r9   rA   rO   rT   r_   rc   rr   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   !   s       - -^B B B    ^2     ^:  37"	% % % % % ^%N@ @ @ @(E E E E$   ,   .   2n n n nA A A A               @ # # # # #B 	    D EIJN!A2 A2 A2 A2 A2N !$ $ $ $ $L
( 
( 
( 
(
1 
1 
1 
1
* 
* 
* 
*
 
 
 
 
( JL= = = = = = =r   r   c                  2     e Zd ZdZ fdZd Z fdZ xZS )	TupleDicta  A dictionary that supports tuple unpacking.

    This class extends dict to allow unpacking like a tuple while still
    providing dictionary-style key access. The unpacking order is determined
    by the _keys attribute which preserves insertion order.

    Example:
        >>> result = TupleDict(a=1, b=2, c=3)
        >>> x, y, z = result  # Unpacks as (1, 2, 3)
        >>> result['a']  # Returns 1
        >>> result[0]  # Returns 1 (integer indexing)
    c                     t                      j        |i | t          |                                           | _        d S r   )superr(   r   keys_keys)r'   r2   r3   r#   s      r   r(   zTupleDict.__init__q  s:    $)&)))$))++&&


r   c                *      fd j         D             S )z;Iterate over values in insertion order for tuple unpacking.c              3  (   K   | ]}|         V  d S r   r   )ri   kr'   s     r   rl   z%TupleDict.__iter__.<locals>.<genexpr>x  s'      ,,AQ,,,,,,r   )r   r&   s   `r   __iter__zTupleDict.__iter__v  s    ,,,,,,,,r   c                T   t          |t                    rr|dk     s|t          | j                  k    r(t	          d| dt          | j                   d          t                                          | j        |                   S t                                          |          S )z-Support both string keys and integer indices.r   zindex z! out of range for TupleDict with z items)rX   r   lenr   
IndexErrorr   __getitem__)r'   r   r#   s     r   r   zTupleDict.__getitem__z  s    c3 	8Qww#TZ00 !g#!g!gPSTXT^P_P_!g!g!ghhh77&&tz#777ww""3'''r   )r$   r   r   r   r(   r   r   __classcell__)r#   s   @r   r   r   c  sj         ' ' ' ' '
- - -( ( ( ( ( ( ( ( (r   r   )r   r   r   r   r   r   )r   r   r   r   )r   
__future__r   abcr   r   typingr   r   r   r	   r!   cuda.bindings.driverbindingsdriverr?   r>   rY   cutlass.cuter   cudnn.datatypesr
   r   r   r   dictr   r   r   r   <module>r      sH    # " " " " " # # # # # # # # - - - - - - - - - - - -  # # # # # # # # #         9 9 9 9 9 9   ( ( ( (
= = = = =c = = =D( ( ( ( ( ( ( ( ( (r   