
    `i~                        d dl Z d dlmZmZ d dlZd dlmZ d dlmZm	Z	m
Z
mZmZ d dlmZ d dlmZ d dlmZ  G d d          Z G d	 d
          Zdedej        dej        defdZdedej        defdZdej        deeeef         dedeeeeeef         f         fdZ G d de j                  Z G d d          ZdS )    N)TupleOptional)Boolean)Int32Float32minextract_mlir_valuesnew_from_mlir_values)HardwareInfo)WorkTileInfoc                   @    e Zd ZdZddddedej        fdZd Zd Z	dS )	FmhaStaticTileSchedulerParamsa  A class to represent parameters for the FMHA (Fused Multi-Head Attention) static tile scheduler.

    This class holds the configuration parameters needed to initialize and configure
    the tile scheduler for FMHA operations.

    :ivar is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :ivar problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape
    Nlocipis_persistentproblem_shape_mbhc                >    || _         || _        || _        || _        dS )a3  
        Initializes the FmhaStaticTileSchedulerParams with the given parameters.

        :param is_persistent: Whether to use persistent kernel mode.
        :type is_persistent: bool
        :param problem_shape_mbh: Problem shape in (M, B, H) format.
        :type problem_shape_mbh: cute.Shape
        N)r   r   _loc_ip)selfr   r   r   r   s        /home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/cudnn/native_sparse_attention/compression/fmha_helpers.py__init__z&FmhaStaticTileSchedulerParams.__init__-   s&      +!2	    c                     g g c}| _         | j        fD ]=}t          |          }||z  }| j                             t	          |                     >|S N)_values_posr   r	   appendlen)r   valuesobj
obj_valuess       r   __extract_mlir_values__z5FmhaStaticTileSchedulerParams.__extract_mlir_values__B   sb    #%r  *+ 	5 	5C,S11Jj F##C
OO4444r   c           	          g }t          | j        g| j                  D ]:\  }}|                    t	          ||d |                              ||d          };t          | j        gt          |          R d| j        iS )Nr   )	zipr   r   r   r
   r   r   tupler   )r   r    obj_listr!   n_itemss        r   __new_from_mlir_values__z6FmhaStaticTileSchedulerParams.__new_from_mlir_values__J   s    !7 8$:JKK 	& 	&LCOO0fXgX6FGGHHHGHH%FF,T-?c5??cccY]Ybcccr   )
__name__
__module____qualname____doc__boolcuteShaper   r#   r)    r   r   r   r   !   s{        	 	      :   *  d d d d dr   r   c            	           e Zd ZdZddddededej        dej        fdZ	e
dddded	ej        fd
            Ze
dededed	efd            Zdddd	efdZddddZdddddZd Zd ZdS )FmhaStaticTileSchedulerai  A static tile scheduler for FMHA (Fused Multi-Head Attention) operations.

    This class manages the scheduling of work tiles for FMHA kernels, supporting
    both persistent and non-persistent kernel modes. It tracks the current work
    position and advances through the problem space efficiently.

    :ivar _params: Scheduler parameters.
    :type _params: FmhaStaticTileSchedulerParams
    :ivar _blk_coord: Block coordinates.
    :type _blk_coord: cute.Coord
    :ivar _grid_shape: Grid shape for the kernel.
    :type _grid_shape: cute.Shape
    :ivar _is_persistent: Whether to use persistent kernel mode.
    :type _is_persistent: bool
    :ivar _current_work_linear_idx: Current linear work index.
    :type _current_work_linear_idx: Int32
    :ivar _problem_shape_mbh: Problem shape in (M, B, H) format.
    :type _problem_shape_mbh: cute.Layout
    :ivar _num_blocks: Number of blocks in the problem.
    :type _num_blocks: Int32
    :ivar _is_first_block: Whether this is the first block.
    :type _is_first_block: bool
    :ivar num_persistent_sm: Number of persistent SMs.
    :type num_persistent_sm: Int32
    Nr   paramscurrent_work_linear_idx	blk_coord
grid_shapec                <   || _         || _        || _        |j        | _        || _        t          j        |j        ||          | _	        t          j
        | j	        ||          | _        d| _        t          j
        |||          | _        || _        || _        dS )a  
        Initializes the FmhaStaticTileScheduler with the given parameters.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams
        :param current_work_linear_idx: Current linear work index.
        :type current_work_linear_idx: Int32
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param grid_shape: Grid shape for the kernel.
        :type grid_shape: cute.Shape
        r   TN)_params
_blk_coord_grid_shaper   _is_persistent_current_work_linear_idxr/   make_layoutr   _problem_shape_mbhsize_num_blocks_is_first_blocknum_persistent_smr   r   )r   r4   r5   r6   r7   r   r   s          r   r   z FmhaStaticTileScheduler.__init__m   s    , #%$2(?%"&"263KQTY["\"\"\9T%<#"MMM#!%:32!F!F!F	r   returnc                    | j         rOt                      }|                                }t          |t	          j        | j        ||                    ddfS | j        S )a  
        Determine the grid shape for the FMHA kernel.

        For persistent kernels, the grid shape is limited by the number of SMs
        (Streaming Multiprocessors) available on the device. For non-persistent
        kernels, the grid shape matches the problem shape.

        :param params: Scheduler parameters.
        :type params: FmhaStaticTileSchedulerParams

        :return: Grid shape as (M, B, H) tuple.
        :rtype: cute.Shape
        r      )r   r   get_device_multiprocessor_countr   r/   r@   r   )r4   r   r   hardware_infosm_counts        r   get_grid_shapez&FmhaStaticTileScheduler.get_grid_shape   se    (  		,(NNM$DDFFHHdi(@cbQQQRR  ++r   q_tilercurrent_idxseqlen_qc                     || z  |k     S )a7  
        Check if the current work index is valid for the given query sequence length.

        This method verifies that the current work tile index multiplied by the
        query tiler size is within the bounds of the query sequence length.

        :param q_tiler: Query tiler size.
        :type q_tiler: int
        :param current_idx: Current work index.
        :type current_idx: Int32
        :param seqlen_q: Query sequence length.
        :type seqlen_q: Int32

        :return: True if the work is valid, False otherwise.
        :rtype: Boolean
        r1   )rK   rL   rM   s      r   check_valid_work_for_seqlen_qz5FmhaStaticTileScheduler.check_valid_work_for_seqlen_q   s    , W$x//r   c                    | j         r| j        | j        k     n| j        }d}| j         r#| j                            | j        ||          }n| j        }|d         d|d         |d         ff}t          ||          S )aA  
        Get information about the current work tile.

        Determines if the current work is valid and computes the tile coordinates
        based on whether the kernel is persistent or non-persistent.

        :return: WorkTileInfo containing tile coordinates and validity flag.
        :rtype: WorkTileInfo
        )r   r   r   r   r   rF      )r<   r=   rA   rB   r?   get_hier_coordr:   r   )r   r   r   is_validr6   cur_tile_coords         r   get_current_workz(FmhaStaticTileScheduler.get_current_work   s     HLGZt4043CCC`d`t	 	(/>>t?\bejl>mmIII aLq\9Q<(
 NH555r   c                0    |                      ||          S )z}
        Get the initial work tile information.

        :return: Initial WorkTileInfo.
        :rtype: WorkTileInfo
        r   )rU   )r   r   r   s      r   initial_work_tile_infoz.FmhaStaticTileScheduler.initial_work_tile_info   s     $$$444r   rF   )advance_countr   r   c                R    | j         r| xj        || j        z  z  c_        d| _        dS )a7  
        Advance to the next work tile.

        For persistent kernels, advances by the number of persistent SMs.
        For non-persistent kernels, marks that the first block has been processed.

        :param advance_count: Number of steps to advance (default: 1).
        :type advance_count: int
        FN)r<   r=   rC   rB   )r   rX   r   r   s       r   advance_to_next_workz,FmhaStaticTileScheduler.advance_to_next_work   s:      	T))]T=S-SS))$r   c                    t          | j                  }|                    t          | j                             |                    t          | j                             |                    t          | j                             |S r   )r	   r9   extendr=   r:   r;   )r   r    s     r   r#   z/FmhaStaticTileScheduler.__extract_mlir_values__   so    $T\22)$*GHHIII)$/::;;;)$*:;;<<<r   c                 6   t          |          dk    sJ t          | j        |dd                   }t          | j        |d         g          }t          | j        |dd                   }t          | j        |dd                    }t          ||||          S )N
   r            )r   r
   r9   r=   r:   r;   r3   )r   r    
new_paramsnew_current_work_linear_idxnew_blk_coordnew_grid_shapes         r   r)   z0FmhaStaticTileScheduler.__new_from_mlir_values__   s    6{{b    )$,qsDD
&:4;X[abc[dZe&f&f#,T_fQqSkJJ-d.>qrr
KK&z3NP]_mnnnr   )r*   r+   r,   r-   r   r   r/   Coordr0   r   staticmethodrJ   intr   rO   r   rU   rW   rZ   r#   r)   r1   r   r   r3   r3   R   s        B      -  "'  :	 
 J       F  	, , ,-,
 
, , , \,< 000 0 
	0 0 0 \0. '+t 6 6 6 6 6 6 66 -1T 5 5 5 5 5 564D % % % % %  o o o o or   r3   r4   r6   r7   rD   c                 2    t          | |d         ||          S )aq  
    Create a new FMHA static tile scheduler.

    :param params: Scheduler parameters.
    :type params: FmhaStaticTileSchedulerParams
    :param blk_coord: Block coordinates.
    :type blk_coord: cute.Coord
    :param grid_shape: Grid shape.
    :type grid_shape: cute.Shape

    :return: New FmhaStaticTileScheduler instance.
    :rtype: FmhaStaticTileScheduler
    r   )r3   )r4   r6   r7   s      r   !create_fmha_static_tile_schedulerrj   	  s    $ #69Q<JOOOr   r   r   c                 "    t          | |          S )ad  
    Create FMHA static tile scheduler parameters.

    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool
    :param problem_shape_mbh: Problem shape in (M, B, H) format.
    :type problem_shape_mbh: cute.Shape

    :return: New FmhaStaticTileSchedulerParams instance.
    :rtype: FmhaStaticTileSchedulerParams
    )r   )r   r   s     r   (create_fmha_static_tile_scheduler_paramsrl     s     )8IJJJr   o_shape	cta_tilerc           	      >   t          |t          j        t          j        | d                   |d                   t          j        | d         d                   t          j        | d         d                   f          }t                              |          }||fS )ak  
    Compute grid parameters for FMHA operation.

    This function calculates the appropriate grid shape and scheduler parameters
    based on the output tensor shape, CTA (Cooperative Thread Array) tiler,
    and whether to use persistent kernel mode.

    The output tensor o has shape (s, d, ((h_r, h_k), b)) where:
    - s: sequence length
    - d: head dimension
    - h_r: number of heads for query
    - h_k: number of heads for key
    - b: batch size

    :param o_shape: Output tensor shape for grid computation.
    :type o_shape: cute.Shape
    :param cta_tiler: CTA tiler dimensions (M, N, K).
    :type cta_tiler: Tuple[int, int, int]
    :param is_persistent: Whether to use persistent kernel mode.
    :type is_persistent: bool

    :return: Tuple of (scheduler_params, grid_shape).
    :rtype: Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]
    r   rQ   rF   )rl   r/   ceil_divr@   r3   rJ   )rm   rn   r   tile_sched_paramsgrids        r   compute_gridrs   0  s    : AM$)GAJ//1>>Igajm$$Igajm$$	
  #112CDDDd""r   c                       e Zd ZdZ ej                    Z ej                    Z ej                    Z ej                    Z	dS )MaskTypeaN  Enumeration of mask types for FMHA operations.

    - RESIDUAL_MASK: Residual mask for handling variable sequence lengths
    - WINDOW_MASK: Window mask for attention which also includes causal and no mask
    - WINDOW_MASK_INFERENCE: Same as the window mask, but has the limitation that the end of q is aligned with the end of k
    N)
r*   r+   r,   r-   enumautoRESIDUAL_MASKWINDOW_MASKWINDOW_MASK_INFERENCECOMPRESSED_CAUSAL_MASKr1   r   r   ru   ru   _  sR          DIKKM$)++K%DIKK&TY[[r   ru   c                      e Zd ZdZej        	 	 ddedej        dej        de	de	de
e	         d	e
e	         d
e	fd            Zej        	 ddedej        dej        de	de	de
e	         d
e	fd            Zej        	 	 ddedej        dej        de	de	de
e	         d	e
e	         d
ee	e	f         fd            Zej        	 	 ddedej        dej        de	de	de
e	         d	e
e	         d
ee	e	f         fd            Zej        	 	 ddedej        dej        de	de	de
e	         d	e
e	         d
e	fd            Zej        	 	 	 ddedej        dej        de	de	de
e	         d	e
e	         de
e	         d
e	fd            Zej        	 	 ddedej        dej        de	de	de
e	         d	e
e	         d
e	fd            Zej        	 	 ddedej        dej        de	de	de
e         d	e
e         fd            ZdS )	FusedMaska  A fused mask implementation for FMHA operations.

    This class handles different types of attention masks including no mask,
    residual mask for variable sequence lengths, and causal mask for
    autoregressive attention patterns.

    The class provides methods to:
    - Calculate trip counts for different mask types
    - Apply masks to attention scores
    - Handle masked and unmasked trip calculations
    N	mask_typer6   
tile_shaperM   seqlen_kwindow_size_leftwindow_size_rightrD   c                    d}t          j        | t          j        u          rdn||z
  }t          j        | t          j        k              rt          j        ||d                   }t          j        | t          j        k    p| t          j        k              rt          j        |du           rt          j        ||d                   }n|d         dz   |d         z  }	|	|z   |z   }
t          j        |
|d                   }t          j        ||d                   }t          ||          }nt          j        | t          j	        k              r|||z  }|d         dz   |d         z  dz
  |z   |z   }t          j        |dz   |z  |d                   }t          j        ||d                   }t          dt          ||                    }t                              | |||||          }||z
  }|S )a  
        Calculate the number of trips needed for the current block.

        The trip count depends on the mask type and the block coordinates.
        For causal masks, it considers the autoregressive constraint.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of trips needed.
        :rtype: Int32
        r   rF   N)cutlass
const_exprru   rz   rx   r/   rp   ry   r   r{   maxr}   get_trip_start)r~   r6   r   rM   r   r   r   resultoffset	max_idx_qidx_ktmp_blocks_kmax_blocks_kcompression_factor	block_endstart_blocks                   r   get_trip_countzFusedMask.get_trip_countz  s   D ((:X)XYYr_gjr_ri8+AABB 	<]8Z];;Fi8+??n9PXPnCnoo 	=!"3t";<< 9xA??&q\A-A>	!F*->>#}UJqMBB#}Xz!}EE\<88	X-L LMM 	=!)X!5"1)Z]:Q>GJ[[I=9q==O*OR\]^R_``L=:a=AALCl;;<<F..y)ZQY[ceuvv+%r   c                     d}t          j        | t          j        u          rdn||z
  }t          j        |du          r4|d         |d         z  }||z   |z
  }	|	|d         z  }
t	          |
|          }|S )a  
        Get the start of the trip for the current block.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        r   NrF   )r   r   ru   rz   r   )r~   r6   r   rM   r   r   r   r   	min_idx_qr   r   s              r   r   zFusedMask.get_trip_start  s    6 ((:X)XYYr_gjr_r.d:;; 	/!!z!}4I&)99E JqM1Lv..Fr   c           	      d   t          j        | t          j        u          rdn||z
  }t                              | |||||          }t                              | ||||||          }	|d         dz   |d         z  |z   |z
  }
t          t          |
|d         z  d          |	|z   dz
            }||fS )a  
        Get the begin and end tile idx for the leading mask.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the leading mask.
        :rtype: Tuple[Int32, Int32]
        r   rF   )	r   r   ru   rz   r}   r   r   r   r   )r~   r6   r   rM   r   r   r   r   leading_mask_begin
trip_countr   leading_mask_ends               r   get_leading_mask_idzFusedMask.get_leading_mask_id  s    > ((:X)XYYr_gjr_r&55iJX`bjl|}}--
 

 q\A%A6?BRR	s9
1#=qAA:PbCbefCfgg!#333r   c           	      n   t          j        | t          j        u          rdn||z
  }t                              | |||||          }t                              | ||||||          }	|d         |d         z  |z   |z   }
t          t          |
|d         z  |	|z   dz
            d          }|	|z   dz
  }||fS )a  
        Get the begin and end tile idx for the trailing mask.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Tuple of (begin, end) tile idx for the trailing mask.
        :rtype: Tuple[Int32, Int32]
        r   rF   )	r   r   ru   rz   r}   r   r   r   r   )r~   r6   r   rM   r   r   r   r   
trip_startr   r   trailing_mask_begintrailing_mask_ends                r   get_trailing_mask_idzFusedMask.get_trailing_mask_id  s    > ((:X)XYYr_gjr_r--iJPXZbdtuu
--
 

 aL:a=069<MM	!#i:a=&@*zBY\]B]"^"^`abb&3a7"$555r   c           	          d}t          j        |du          r+t                              | ||||||          \  }}	|	|z
  dz   }|S )a  
        Calculate the number of masked trips for the leading mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of masked trips.
        :rtype: Int32
        r   NrF   )r   r   r}   r   )
r~   r6   r   rM   r   r   r   r   r   r   s
             r   get_masked_leading_countz"FusedMask.get_masked_leading_count:  sk    B .d:;; 
	?3<3P3P !4 40 0 &(::Q>Fr   r   	rem_countc           	      B   d}t          j        | t          j        k    p| t          j        k              rt          j        |du          ryt
                              | ||||||          \  }	}
t          j        |du          r8t
                              | ||||||          \  }}|	|k    r|
|z
  }n|
|	z
  dz   }n|
|	z
  dz   }nt          j        | t          j        k              r||d         z  dk    rd}nd}n| t          j	        k    r||z  }|d         |d         z  }t          ||d         dz   |d         z  dz
            }t
                              | ||||||          }|dz   |z  |d         z  }|dz   |z  |d         z  }t          ||z
  |          }||z   S )a  
        Calculate the number of masked trips for the trailing mask.

        This is used for blocks that need special handling due to masking.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]
        :param rem_count: Remaining count from previous calculations.
        :type rem_count: Int32

        :return: Number of masked trips.
        :rtype: Int32
        r   NrF   )r   r   ru   ry   rz   r}   r   r   rx   r{   r   r   r   )r~   r6   r   rM   r   r   r   r   r   r   r   r   r   r   block_startr   r   mask_start_tripmask_end_trips                      r   get_masked_trailing_countz#FusedMask.get_masked_trailing_countj  s   H i8+??n9PXPnCnoo 5	F!"34"?@@ I9B9W9W$%: :6#%6 %&6d&BCC I;D;X;X!!"  ()< <8&(8 +.>>>!25E!E!25H!H1!L.1DDqHF	X-C CDD 	F*Q-'1,,(999!)X!5#A,A6KHy|a'7:a=&H1&LMMI"11 ! J !,a4FF:VW=XO'!m0BBzRS}TM8*EEF	!!r   c                     t                               | ||||||          t                               | ||||||          z
  t                               | ||||||d          z
  }|S )a  
        Calculate the number of unmasked trips for the current block.

        This represents the number of trips that don't require special
        masking treatment.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param blk_coord: Block coordinates.
        :type blk_coord: cute.Coord
        :param tile_shape: Shape of the tile.
        :type tile_shape: cute.Shape
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Int32
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[Int32]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[Int32]

        :return: Number of unmasked trips.
        :rtype: Int32
        r   )r}   r   r   r   )r~   r6   r   rM   r   r   r   r   s           r   get_unmasked_trip_countz!FusedMask.get_unmasked_trip_count  s    F $$ !  00 ! $ 11 !	 	% 	< r   acc_qkindex_qkc                 N   d}t          j        | t          j        k              r||z
  }t          j        t          j        |                    D ]}||         \  }	}
t          j        |dup|du          rlt          j        | t          j        k              rH||z  }|	dz   |z  dz
  |
k     s|
|k    rt          j	         ||<   |
|k    s|	|k    rt          j	         ||<   nt          j        |du           r9|	|z   |z   |
k     rt          j	         ||<   |
|k    s|	|k    rt          j	         ||<   nt          j        |du           r9|	|z   |z
  |
k    rt          j	         ||<   |
|k    s|	|k    rt          j	         ||<   ndt          |	|z   |z   |          }t          d|	|z   |z
            }|
|k    s|
|k     rt          j	         ||<   |
|k    s|	|k    rt          j	         ||<   t          j        | t          j        k              r|
|k    s|	|k    rt          j	         ||<   dS )a  
        Apply the appropriate mask to the attention scores.

        This method modifies the attention scores (acc_qk) based on the mask type
        and the positions in the index tensor.

        :param mask_type: Type of mask to use
        :type mask_type: utils.MaskType
        :param acc_qk: Accumulated QK attention scores tensor.
        :type acc_qk: cute.Tensor
        :param index_qk: Index tensor containing position information.
        :type index_qk: cute.Tensor
        :param seqlen_k: Key sequence length for attention computation.
        :type seqlen_k: Int32
        :param seqlen_q: Query sequence length for attention computation.
        :type seqlen_q: Optional[int]
        :param window_size_left: Left-side sliding window size for attention masking.
        :type window_size_left: Optional[int]
        :param window_size_right: Right-side sliding window size for attention masking.
        :type window_size_right: Optional[int]
        r   NrF   )r   r   ru   rz   ranger/   r@   r{   r   infr   r   rx   )r~   r   r   rM   r   r   r   r   iindex_qindex_kr   max_K_indexmin_K_indexs                 r   
apply_maskzFusedMask.apply_mask  s~   @ i8+IIJJ 	)(Fty0011 	- 	-A'{GW!"2$">"_BS[_B_`` 1%i83R&RSS 1)1X)=&!(::Q>HHGW_L_L_%,[Lq	(**g.A.A%,[Lq	'(8D(@AA 1'*;;gEE%,[Lq	(**g.A.A%,[Lq	'(9T(ABB 1'*::WDD%,[Lq	(**g.A.A%,[Lq	"%g&69J&JH"U"UK"%a6)9<L)L"M"MK,,+0E0E%,[Lq	(**g.A.A%,[Lq	!)x/E"EFF -h&&'X*=*=!(F1I;	- 	-r   )NNr   )NNr   )r*   r+   r,   r-   r/   jitru   rf   r0   r   r   r   r   r   r   r   r   r   r   Tensorrh   r   r1   r   r   r}   r}   m  ss       
 
 
X -1-1< <<:< J< 	<
 < #5/< $E?< 
< < < X<| 
X -1! !!:! J! 	!
 ! #5/! 
! ! ! X!F 
X -1-1+4 +4+4:+4 J+4 	+4
 +4 #5/+4 $E?+4 
ue|	+4 +4 +4 X+4Z 
X -1-1,6 ,6,6:,6 J,6 	,6
 ,6 #5/,6 $E?,6 
ue|	,6 ,6 ,6 X,6\ 
X -1-1- --:- J- 	-
 - #5/- $E?- 
- - - X-^ 
X -1-1%&\" \"\":\" J\" 	\"
 \" #5/\" $E?\" E?\" 
\" \" \" X\"| 
X -1-1? ??:? J? 	?
 ? #5/? $E?? 
? ? ? X?B 
X +/+/?- ?-?-?- +?- 	?-
 ?- #3-?- $C=?- ?- ?- X?- ?- ?-r   r}   )rv   typingr   r   r   cutlass.cute.typingr   cutlass.cutlass_dslr   r   r   r	   r
   cutlass.utils.hardware_infor   cutlass.utilsr   cutlass.cuter/   r   r3   rf   r0   rj   r.   rl   rh   rs   Enumru   r}   r1   r   r   <module>r      sj    " " " " " " " "  ' ' ' ' ' '              5 4 4 4 4 4 & & & & & &      .d .d .d .d .d .d .d .dbto to to to to to to tonP)PzP 
P 	P P P P*KKzK #K K K K$'#Z'#S#s]#'# '# (%S#*>>?	'# '# '# '#^) ) ) ) )ty ) ) )^- ^- ^- ^- ^- ^- ^- ^- ^- ^-r   