
    )`iА                       d Z ddlZddlZddlZddlmZmZ ddlZddlm	Z	 ddl
Z
ddlmZmZmZmZmZmZ ddlmZmZ ddlmZ ddlmZ d	Zd
ZdZdZ ej        d          dhdee
j        z  ez  dz  defd            Z edddde	j!        dedefd            Z"eddddede	j!        de	j!        deddf
d            Z#edddde	j$        de	j!        fd            Z%eddddedefd            Z&edddde	j$        dedefd            Z'edddd edefd!            Z(eddddedefd"            Z)edddd ed#edefd$            Z*edddd ed#edefd%            Z+eddddedeeeeef         fd&            Z,eddddedefd'            Z-edddd(edefd)            Z.edddd*edefd+            Z/edddd,edefd-            Z0edddd ed#edefd.            Z1edddd ed#edefd/            Z2eddddedefd0            Z3edddd ed#edefd1            Z4eddddedefd2            Z5eddddedefd3            Z6edddd ed#edefd4            Z7eddddedefd5            Z8edddd6ed7edeeef         fd8            Z9edddd9ed7edeeef         fd:            Z:edddd ed#edefd;            Z;edddd ed#edefd<            Z<edddd=ed>ed?ed@edAedBedCedDedefdE            Z=e	j>        didGej?        e         fdH            Z@e	j>        dedIedJe	j$        dKedef
dL            ZAe	j>        dedIedJe	j$        de	j!        dMej?        e         dKedefdN            ZBe	j>        de	jC        dIe	jD        dOej?        e         dJe	j$        dMej?        e         dKefdP            ZEe	j>        dQe	j$        dRede	j$        fdS            ZF G dT dU          ZGejH        dVedWedXeIdYedZed[eIdefd\            ZJe	 	 	 	 	 	 	 djd_e
j$        d`e
j$        dae
j$        dbe
j$        dz  dce
j$        dz  dde
j$        dz  deeKdWedZedz  d[eIdee
j$        e
j$        f         fdf            ZLg dgZMdS )kaq  
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Fused Add + RMSNorm + FP4 Quantization using CuTe-DSL
======================================================

High-performance fused kernel for element-wise addition followed by RMS normalization
and FP4 quantization. Supports both NVFP4 (block_size=16, E4M3 scales) and MXFP4
(block_size=32, UE8M0 scales) formats.

    N)CallableTuple)Float32Int32Int64Uint32Uint64Uint8)Tdsl_user_op)llvm   )flashinfer_apig      @g      |@      )maxsizedevicereturnc                     t           j                                        sdS | t           j                                        } t           j                            |           }|j        dz  |j        z   S )a  Get the SM version of a CUDA device.

    Args:
        device: CUDA device to query. Can be an int (device index), torch.device,
            device string (e.g., 'cuda:0'), or None to use current device.

    Returns:
        SM version as an integer (e.g., 100 for SM100).
    P   N
   )torchcudais_availablecurrent_deviceget_device_propertiesmajorminor)r   propss     |/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/flashinfer/cute_dsl/add_rmsnorm_fp4quant.pyget_sm_versionr!   6   s`     :""$$ r~**,,J,,V44E;ek))    locipsmem_ptrpeer_cta_rank_in_clusterc                   |                      ||                                          }t          t          j        t          j                    ||                                gddddt          j        j                            S )z?Map smem pointer to address at another CTA rank in the cluster.r#   z$mapa.shared::cluster.u32 $0, $1, $2;=r,r,rFhas_side_effectsis_align_stackasm_dialect)	tointir_valuer   r   
inline_asmr   i32
AsmDialectAD_ATT)r&   r'   r$   r%   smem_ptr_i32s        r    set_block_rankr5   N   s}    
 >>cb>11::<<LEGG3<<>>?2" .	
 	
 	

 
 
r"   valmbar_ptrc          	         t          ||||                                          }t          ||||                                          }t          j        d||                     ||          |gddddt          j        j                   dS )zDStore Float32 value to shared memory on a remote CTA in the cluster.r#   NzIst.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];zr,f,rTFr*   )r5   r/   r   r0   r2   r3   )r6   r&   r7   r'   r$   r%   remote_smem_ptr_i32remote_mbar_ptr_i32s           r    store_shared_remoter;   a   s     )*  hjj  )*  hjj  	O	cllsrl::<OPSO*     r"   xc                L    | j         t          j        || j        ||          z   S )z/Get pointer to element at coordinate in tensor.r#   )iteratorcutecrd2idxlayout)r<   coordr$   r%   s       r    elem_pointerrC   }   s&     :UAH#"EEEEEr"   base_ptrvaluec          	          t          j        dt          |                               ||          t	          |                              ||          gddddt           j        j                   dS )zStore 64 bits to global memory.Nr#   zst.global.u64 [$0], $1;zl,lTFr*   )r   r0   r   r/   r	   r2   r3   )rD   rE   r$   r%   s       r    st_global_u64rG      s}     	O(OO$$$445MM""sr"22	
 	"O*     r"   tensoroffsetc                    | j         t          |          z   }t          j        t	          j                    |j        ||          }t          |          S )z2Get the memory address of tensor[offset] as Int64.r#   )r>   r   r   ptrtointr   i64llvm_ptrr   )rH   rI   r$   r%   elem_ptrptr_ints         r    get_ptr_as_int64rP      sC     v.HmAEGGX%6CBGGGG>>r"   ac                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z-Fast reciprocal using PTX rcp.approx.ftz.f32.r#   zrcp.approx.ftz.f32 $0, $1;=f,fFr*   r   r   r0   r   f32r/   r2   r3   )rQ   r$   r%   s      r    rcp_approx_ftzrV      sf     EGGQZZ  SR 001(" .	
 	
 	

 
 
r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z"Compute absolute value of float32.r#   zabs.f32 $0, $1;rS   Fr*   rT   r6   r$   r%   s      r    fabs_f32rY      sf     EGGS\\""sr"223" .	
 	
 	

 
 
r"   bc                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z"Compute max of two float32 values.r#   zmax.f32 $0, $1, $2;=f,f,fFr*   rT   rQ   rZ   r$   r%   s       r    fmax_f32r^           EGGQZZ  SR 00'!**2E2E#RT2E2U2UV!" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z8Compute min of two float32 values (branchless clamping).r#   zmin.f32 $0, $1, $2;r\   Fr*   rT   r]   s       r    fmin_f32ra      r_   r"   c                6   t          j        t           j                            t	          j                    t	          j                    t	          j                    t	          j                    g          t          |                               ||          gddddt           j        j	        ||	  	        }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |d	g||          }t          |          t          |          t          |          t          |          fS )
z.Load 128 bits (4 x uint32) from global memory.r#   z(ld.global.v4.u32 {$0, $1, $2, $3}, [$4];z=r,=r,=r,=r,lFr+   r,   r-   r$   r%   r      r      )r   r0   
StructTypeget_literalr   r1   r   r/   r2   r3   extractvaluer   )rD   r$   r%   resultv0v1v2v3s           r    ld_global_v4_u32rn      s<   
 _##QUWWaeggquww$HII	x	!	!cb	!	1	122O*
 
 
F 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B		1577FQCSR	@	@	@B		1577FQCSR	@	@	@B"::vbzz6"::vbzz99r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )zConvert float32 to FP8 E4M3.r#   a  
            {
                .reg .b16 fp8_pair;
                .reg .f32 zero;
                mov.f32 zero, 0f00000000;
                cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $1;
                cvt.u32.u16 $0, fp8_pair;
            }
            =r,fFr*   	r   r   r0   r   r1   r   r/   r2   r3   rX   s      r    cvt_f32_to_e4m3rr      sj     EGGS\\""sr"223 " .	
 	
 	
  r"   fp8_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z3Convert FP8 E4M3 to float32 and compute reciprocal.r#   a
  
            {
                .reg .pred p_zero;
                .reg .u32 exp_u, mant_u;
                .reg .s32 exp_s;
                .reg .f32 exp_f, mant_f, fp8_float, result;
                setp.eq.u32 p_zero, $1, 0;
                and.b32 mant_u, $1, 7;
                shr.b32 exp_u, $1, 3;
                and.b32 exp_u, exp_u, 15;
                sub.s32 exp_s, exp_u, 7;
                cvt.rn.f32.s32 exp_f, exp_s;
                ex2.approx.f32 exp_f, exp_f;
                cvt.rn.f32.u32 mant_f, mant_u;
                fma.rn.f32 mant_f, mant_f, 0f3E000000, 0f3F800000;
                mul.f32 fp8_float, exp_f, mant_f;
                rcp.approx.ftz.f32 result, fp8_float;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            =f,rFr*   	r   r   r0   r   rU   r   r/   r2   r3   )rs   r$   r%   s      r    fp8_e4m3_to_f32_and_rcprw     sj     EGGG__%%#"%556( " .5	
 	
 	
  r"   max_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z0Convert float32 max value to UE8M0 scale factor.r#   aa  
            {
                .reg .pred p_zero, p_neg, p_ovf;
                .reg .f32 log2_val;
                .reg .s32 exp_int, result;
                setp.le.f32 p_zero, $1, 0f00000000;
                lg2.approx.f32 log2_val, $1;
                cvt.rpi.s32.f32 exp_int, log2_val;
                add.s32 result, exp_int, 127;
                setp.lt.s32 p_neg, result, 0;
                setp.gt.s32 p_ovf, result, 255;
                selp.s32 result, 0, result, p_neg;
                selp.s32 result, 255, result, p_ovf;
                selp.s32 $0, 0, result, p_zero;
            }
            rp   Fr*   rq   )rx   r$   r%   s      r    cvt_f32_to_ue8m0rz   :  sl     EGGW&&32&667  " .-	
 	
 	
  r"   	ue8m0_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z4Convert UE8M0 to output_scale (1 / 2^(ue8m0 - 127)).r#   a  
            {
                .reg .pred p_zero;
                .reg .s32 neg_exp;
                .reg .f32 neg_exp_f, result;
                setp.eq.u32 p_zero, $1, 0;
                sub.s32 neg_exp, 127, $1;
                cvt.rn.f32.s32 neg_exp_f, neg_exp;
                ex2.approx.f32 result, neg_exp_f;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            ru   Fr*   rv   )r{   r$   r%   s      r    ue8m0_to_output_scaler}   Y  sl     EGGI''CB'778 " .%	
 	
 	
  r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z'Multiply two Half2 values element-wise.r#   zmul.f16x2 $0, $1, $2;r)   Fr*   r   r   r0   r   r1   r/   r2   r3   r]   s       r    	half2_mulr   y       EGGAYYCB//1C1CPR1C1S1ST#" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z)Multiply two BFloat2 values element-wise.r#   zmul.bf16x2 $0, $1, $2;r)   Fr*   r   r]   s       r    bfloat2_mulr          EGGAYYCB//1C1CPR1C1S1ST$" .	
 	
 	

 
 
r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )zHalf2 absolute value.r#   and.b32 $0, $1, 0x7FFF7FFF;=r,rFr*   r   r<   r$   r%   s      r    habs2r     f     EGGAYYCB//0)" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z-Half2 max - element-wise max of 2 fp16 pairs.r#   zmax.f16x2 $0, $1, $2;r)   Fr*   r   r]   s       r    hmax2r     r   r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z1Extract max of 2 fp16 values in half2 as float32.r#   z
            {
                .reg .b16 h0, h1;
                .reg .f32 f0, f1;
                mov.b32 {h0, h1}, $1;
                cvt.f32.f16 f0, h0;
                cvt.f32.f16 f1, h1;
                max.f32 $0, f0, f1;
            }
            ru   Fr*   rv   r   s      r    hmax_to_f32r     sj     EGGAYYCB//0	 " .!	
 	
 	
  r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )zBFloat16x2 absolute value.r#   r   r   Fr*   r   r   s      r    bfloat2_habs2r     r   r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )zBFloat16x2 max.r#   zmax.bf16x2 $0, $1, $2;r)   Fr*   r   r]   s       r    bfloat2_hmax2r     r   r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z(Extract max of 2 bf16 values as float32.r#   ae  
            {
                .reg .b32 lo, hi;
                .reg .f32 f0, f1;
                and.b32 lo, $1, 0xFFFF;
                shr.b32 hi, $1, 16;
                shl.b32 lo, lo, 16;
                shl.b32 hi, hi, 16;
                mov.b32 f0, lo;
                mov.b32 f1, hi;
                max.f32 $0, f0, f1;
            }
            ru   Fr*   rv   r   s      r    bfloat2_hmax_to_f32r     sj     EGGAYYCB//0 " .'	
 	
 	
  r"   h2scalec                P   t          j        t           j                            t	          j                    t	          j                    g          t          |                               ||          t          |                              ||          gddddt           j	        j
        ||	  	        }t          j        t	          j                    |dg||          }t          j        t	          j                    |dg||          }t          |          t          |          fS )z.Convert half2 to float2 and multiply by scale.r#   z
        {
            .reg .b16 h0, h1;
            .reg .f32 f0, f1;
            mov.b32 {h0, h1}, $2;
            cvt.f32.f16 f0, h0;
            cvt.f32.f16 f1, h1;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        	=f,=f,r,fFrc   r   rd   r   r0   rf   rg   r   rU   r   r/   r   r2   r3   rh   )r   r   r$   r%   ri   f0f1s          r    half2_to_float2_scaledr     s    
 _##QUWWaegg$677					,	,genn.E.E#RT.E.U.UV
	 	O*'  F, 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B2;;##r"   bf2c                P   t          j        t           j                            t	          j                    t	          j                    g          t          |                               ||          t          |                              ||          gddddt           j	        j
        ||	  	        }t          j        t	          j                    |dg||          }t          j        t	          j                    |dg||          }t          |          t          |          fS )z3Convert bfloat16x2 to float2 and multiply by scale.r#   aU  
        {
            .reg .b32 lo, hi;
            .reg .f32 f0, f1;
            and.b32 lo, $2, 0xFFFF;
            shr.b32 hi, $2, 16;
            shl.b32 lo, lo, 16;
            shl.b32 hi, hi, 16;
            mov.b32 f0, lo;
            mov.b32 f1, hi;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        r   Frc   r   rd   r   )r   r   r$   r%   ri   r   r   s          r    bfloat2_to_float2_scaledr   /  s    
 _##QUWWaegg$677			#"		-	-wu~~/F/F3SU/F/V/VW	 	O*-  F2 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B2;;##r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z"Add two Half2 values element-wise.r#   zadd.f16x2 $0, $1, $2;r)   Fr*   r   r]   s       r    hadd2r   S  r   r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z$Add two BFloat2 values element-wise.r#   zadd.bf16x2 $0, $1, $2;r)   Fr*   r   r]   s       r    bfloat2_addr   c  r   r"   rj   rk   rl   rm   v4v5v6v7c                   t          t          j        t          j                    t          |                               ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          gddddt          j        j                            S )zMConvert eight float32 values to eight E2M1 (4-bit) values packed into uint32.r#   a  
            {
                .reg .b8 byte0, byte1, byte2, byte3;
                cvt.rn.satfinite.e2m1x2.f32 byte0, $2, $1;
                cvt.rn.satfinite.e2m1x2.f32 byte1, $4, $3;
                cvt.rn.satfinite.e2m1x2.f32 byte2, $6, $5;
                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $7;
                mov.b32 $0, {byte0, byte1, byte2, byte3};
            }
            z=r,f,f,f,f,f,f,f,fFr*   rq   )
rj   rk   rl   rm   r   r   r   r   r$   r%   s
             r    cvt_e2m1x8_f32r   s  s0    EGG$$$44$$$44$$$44$$$44$$$44$$$44$$$44$$$44		 !" .3	
 	
 	
  r"       widthc           	      F   t          j        t          | t          j                            rt          j        | j        | j                  }|                    |            t          j	        t          j
        | j                            D ]}t          ||         ||          ||<   |                                S t          j	        t          t          j        |                              D ]0} || t          j                            | d|z                      } 1| S )z8Reduce across threads in a warp using butterfly shuffle.rd   )rI   )cutlass
const_expr
isinstancer?   	TensorSSAmake_rmem_tensorshapedtypestorerange_constexprsizewarp_reduceloadintmathlog2archshuffle_sync_bfly)r6   opr   resis        r    r   r     s     *S$.99:: 	#CIsy99		#(39)=)=>> 	4 	4A QU33CFFxxzz(TYu-=-=)>)>?? 	K 	KA"S$)55c!q&5IIJJCC
r"   r   reduction_bufferinit_valc                 f   t           j                                        }t           j                                        }t          j        |j        d                   }||z  }||z  }|dk    r| |||f<   t           j                                         |}	||k     r
|||f         }	t          |	|          S )z:Block reduction across multiple warps using shared memory.rd   r   )r?   r   lane_idxwarp_idxr   r   barrierr   )
r6   r   r   r   r   r   warps_per_rowrow_idxcol_idxblock_reduce_vals
             r    block_reducer     s     y!!##Hy!!##HI.4Q788M-'G&G1}}-0')*I-+GX,=>',,,r"   	cluster_nc           	      >   t           j                                        }t           j                                        }t           j                                        }|j        d         }	|j        d         d         }
||
z  }||
z  }|dk    rct           j                                        5  |	|
z  }||z  dz  }t           j                            ||           ddd           n# 1 swxY w Y   ||k     r%t          | t          ||||ff          ||           t           j        
                    |d           |
|z  }t          j        |d          }|}t          j        |          D ]$}||dz  z   }||k     r |||||f                   }%t          ||          S )z6Cluster reduction across multiple CTAs using mbarrier.r   rd      N)r'   )phaser   )r?   r   block_idx_in_clusterr   r   r   	elect_onembarrier_arrive_and_expect_txr;   rC   mbarrier_waitceil_divr   r   r   )r6   r   r   r7   r   r   cta_rank_in_clusterr   r   rows_per_blockr   r   r   	num_warpsexpected_bytes	num_totalnum_iterr   r   idxs                       r    cluster_reducer     s    )88::y!!##Hy!!##H%+A.N$*1-a0M-'G&G1}}Y  "" 	N 	N&6I&2Q6NI33HnMMM	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N
 ))Gg?R5S+TUU%-		
 	
 	
 	
 	IHA...	)I}Y++H$X.. T TR??!r"24DWc\4RSS',,,s   ).C##C'*C'threads_per_rowc                    |                      ||d          }t          j        j        t          j        t          j        j        t          j        j        i|         }t          |d          }	t          |||	          }
t          |dz  d          }t          j        |dk    p|dk              r>t          j        |dk              rt          |
|||          S t          |
|||||          S |
S )z,Row reduction with optional cluster support.r   )r   reduction_profiler   )r   rd   )reducer?   ReductionOpADDoperatoraddMAXr   fmaxminr   maxr   r   r   r   )r<   r   r   r   r7   r   r   	local_valwarp_op
warp_widthwarp_valr   s               r    
row_reducer     s     h!DDI 	hldin 	
G _b))J9gZ@@@H2-q11M-!+<y1}== i1n-- 	'3CXNNN!'#3Xy(   r"   tXcXlimitc           
         t          j        t          j        t          j        | ddg          t          j        | dg          t          j        | dg          ft          j        | dg          ddf          t          j                  }t	          j        |j        d                   D ]P}t	          j        |j        d                   D ].}t          j        | d|fd|f         d         |          ||d|f<   /Q|S )z,Create predicate tensor for bounds checking.r   rd   )moder   stride)	r?   r   make_layoutr   r   Booleanr   r   	elem_less)r   r   tXpXrest_vrest_ks        r    predicate_kr     s     	$aV,,,	$aS)))	$aS)))
 Id!---q!4	
 	
 	
 	
 
D )$*Q-88  -djm<< 	 	F&*na[!V+,Q/' 'DF"##	 Kr"   c                   <   e Zd ZdZ	 	 d%dej        dedededededz  d	edz  fd
Z	e
dedej        dedefd            Ze
dedefd            Ze
dedefd            Ze
dedededefd            Ze
dededededef
d            ZdefdZej        dej        dej        dej        dej        dej        dej        ded efd!            Zej        dej        dej        dej        dej        dej        dej        ded ed"ej        d#ej        fd$            ZdS )&AddRMSNormFP4QuantKernelz
    Fused Add + RMSNorm + FP4 Quantization Kernel.

    Computes: h = x + r, y = RMSNorm(h) * w, then quantizes y to FP4.
    Supports both NVFP4 (block_size=16) and MXFP4 (block_size=32) formats.
    Nr   H
block_sizeoutput_swizzledis_fp16
sm_versionscale_formatc                 Z   || _         || _        || _        || _        || _        ||nt                      | _        ||dk    rdnd| _        n|| _        |dv sJ d|             | j        dv s
J d            |                     ||| j                  | _	        || j	        z  | _
        |                     | j
                  | _        |                     | j
                  | _        | j        | j        z  | _        t!          | j        dz  d          | _        |j        d	z  }t&          d	z  |z  | _        t!          d| j
        | j        z  | j        z   dz
  | j        z            | _        | j        | j        z  | j        z  | _        ||z  | _        |r||z  }	|	d
z   dz  | _        d| _        d S d S )Nr   ue8m0e4m3r   r   z!block_size must be 16 or 32, got )r  r  z&scale_format must be 'e4m3' or 'ue8m0'rd      re   r      )r   r   r   r   r   r!   r  r  _compute_cluster_nr   	H_per_cta_compute_threads_per_rowr   _compute_num_threadsnum_threadsr   r   r   r   	COPY_BITSvec_sizenum_vec_blockscols_per_tilenum_sf_blocks_per_rownum_k_tilesk_tile_stride)
selfr   r   r   r   r   r  r  
elem_bytesnum_col_vecss
             r    __init__z!AddRMSNormFP4QuantKernel.__init__A  s    
$.(2(>**NDTDT+5+;+;D ,DX%%%'W:'W'W%%% $55554 655 00E4?KKdn,#<<T^LL44T^DD".$2FF !5!;Q??[A%
!Q*4!^t},t/CCaG#$
 

 "]T-@@4CWW%&*_" 	%
?L ,q 0Q6D!$D	% 	%r"   r   c                    |dk     rdS t           j                            t           j                                                  }|j        }|j        dz  }dD ]2}| |z  dk    rt                              | ||          }||k    r|c S 3dS )a'  Compute optimal cluster size based on H and device shared memory.

        Dynamically determines the minimum cluster_n that fits within the
        device's shared memory limit, making it compatible with different
        GPU architectures (e.g., SM100 with 228KB vs SM120 with 128KB).
        Z   rd   r  )rd   r   r   r  r   r   r   )r   r   r   r   shared_memory_per_block_optinr   r   _estimate_smem_bytes)r   r   r  r   max_smem_bytes	elem_sizer   smem_neededs           r    r	  z+AddRMSNormFP4QuantKernel._compute_cluster_nt  s     ??1
001J1J1L1LMM<K1$	) 	! 	!I9}!!2GG9i K n,,     - rr"   r
  c                 V    | dk    rdS | dk    rdS | dk    rdS | dk    rdS | dk    rdS d	S )
z Compute optimal threads per row.@   r  r   r   i   r   i    @      r
  s    r    r  z1AddRMSNormFP4QuantKernel._compute_threads_per_row  sW     ??1#2$2$2%33r"   c                     | dk    rdndS )z Compute total threads per block.r"  r   r#  r$  r%  s    r    r  z-AddRMSNormFP4QuantKernel._compute_num_threads  s      5((ssc1r"   r   r  c                 f   | |z  }t                               |          }t                               |          }||z  }t          |dz  d          }t          dz  |z  }t          d||z  |z   dz
  |z            }	||	z  |z  }
||
z  |z  }|dk    rd|z  ||z  dz  z   S d|z  ||z  |z  dz  z   dz   S )zEstimate shared memory bytes needed for given configuration.

        This is used to dynamically determine cluster_n based on device
        shared memory limits.
        r   rd   r  r   r   )r   r  r  r   r  )r   r   r  r
  r   r  r   r   r  r  r  
tile_bytess               r    r  z-AddRMSNormFP4QuantKernel._estimate_smem_bytes  s     N	2KKIVV.CCINN$7Or1155>Y.	X%7!;O
 
 !>1OC#m3i?
>>z>N]$BQ$FFF z>N]$BY$NQR$RRUVVVr"   r   r   r  r  c                 <    | |f||ff}||z  df|||z  | z  ff}||fS )zBCreate Thread-Value layout for coalesced vectorized memory access.rd   r$  )r   r   r  r  r   r   s         r    _make_tv_layoutz(AddRMSNormFP4QuantKernel._make_tv_layout  sO     n-~&

 &*^h6HI
 f}r"   c                 h   | j         j        dz  }| j        | j        z  |z  }| j        | j        z  |z  }| j        dk    r7| j        | j        z  |z  }| j        | j        z  |z  }| j        | j        z  dz  }nd}d}| j        | j        z  | j        z  dz  }| j        dk    rdnd}||z   |z   |z   |z   |z   S )z$Calculate shared memory requirement.r  rd   r   r   )r   r   r   r  r   r   )r  r  x_tile_bytesr_tile_bytesw_tile_bytesh_tile_bytesreduction_bytes
mbar_bytess           r    _smem_size_in_bytesz,AddRMSNormFP4QuantKernel._smem_size_in_bytes  s   J$)	*T-??)K*T-??)K>Q.1CCiOL.1CCiOL"1D4FFJOOLL#d&884>IAM  .1,,QQ!
   	
 	
r"   mXmRmWmYmSmGlobalScaleMepsc
                    |                      | j        | j        | j        | j                  \  }
}t          j        |
|          }| j        | j        f}|                     ||||||||||
  
        	                    t          j
        || j                  | j        dg| j        ddgt          j        | j        dk              r
d| j        dgnd|                                 |	           dS )a  Host function to launch the kernel.

        Takes tensors directly via TVM-FFI.
        - mX: Input tensor, shape (M, H), row-major
        - mR: Residual tensor, shape (M, H), row-major
        - mW: Weight tensor, shape (H,)
        - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed)
        - mS: Scale factor tensor, shape depends on swizzle mode
        - mGlobalScale: Global scale tensor, shape (1,), float32
        r   rd   N)gridblockclustersmemstream)r*  r   r   r  r  r?   r   r  kernellaunchr   r   r  r   r   r2  )r  r3  r4  r5  r6  r7  r8  r9  r:  r@  tv_shape	tv_stride	tv_layouttiler_mns                 r    __call__z!AddRMSNormFP4QuantKernel.__call__  s
   0 #22 M	
 
) $Xi@@@	');<BBai	
 	

&-4#677K#Q*!$.1"455Q**))++  
 
 
 
 
r"   rE  rF  c                 "x   t           j                                        \  }}}t           j                                        \  }}}| j        }| j        }| j        }| j        }| j        }t          j
        |dk              r%t           j                                        d         }nt          j
        d          }|	j        d         d         }t          |dz  d          }|
d         }||z  }||z  }t          t          t                              }t          j                                        }|                    |j        t          j        |
d          d          }|                    |j        t          j        |
d          d          }t          j
        |dk              rb|                    |j        t          j        |
d          d          }|                    |j        t          j        |
d          d          }t          j
        |dk              r4|                    t          t          j        ||f          d          }d	} nO|                    t          t          j        |||ff          d          }|                    t.          d
          } t          j
        |dk              r|dk    r t           j                            | d           t           j                                         t           j                                         t           j                                         t          j        |j                  }!t          j        ||
||f          }"t          j        ||
||f          }#t          j        |!|
||f          }$t          j
        |dk              rgt          j        |j        t          j        |
d         fd                    }%t          j         |j!        |%          }&t          j        |&|
d|f          }'t          j"        t           j#        j$        %                                |j        tL                    }(t          j'        |(|	|
          })|)(                    |          }*|)(                    |          }+|*)                    |"          },|**                    |          }-|+)                    |#          }.|+*                    |          }/|*)                    |$          }0t          j
        |dk              rT|)(                    |          }1|1)                    |'          }2|1*                    |          }3|**                    |          }4t          j+        |,          }5t          j+        |.          }6tY          |0|          }7|0d         }8|8d         |k     }9|9r0t          j-        |(|,|-|7           t          j-        |(|.|/|7           t          j
        |dk              rt          j-        |(|2|3|7           t           j        .                                 t           j        /                    d           t          j0        |-|5           t          j0        |/|6           |51                                2                    t                    }:|61                                2                    t                    };|:|;z   }<|<|<z  }=t          j
        |dk              r/|<2                    |j                  }>|43                    |>           ti          |=t           j5        j6        ||| |t          d                    }?|?|z  }@t           j7        8                    |@|z   d          }A|d         }Bt          j
        |dk              r=t           j                                         t           j                                         nt           j        9                                 ||z  |z   }C|C|k     3r||z   dz
  |z  }Dtu          |D          D 3]}E||E|z  z   }F|F|k     3r|F|z  }Gt          j
        |dk              r4t          j
        |dk              rt          |||Gdz   f                   }Ht          |||Gdz   f                   }It          |||Gdz   f                   }Jt          |||Gdz   f                   }Kt          |||Gdz   f                   }Lt          |||Gdz   f                   }Mt          |||Gdz   f                   }Nt          |||Gdz   f                   }Ot          |||Gdz   f                   }Pt          |||Gdz   f                   }Qt          |||Gdz   f                   }Rt          |||Gdz   f                   }St          |||Gdz   f                   }Tt          |||Gdz   f                   }Ut          |||Gdz   f                   }Vt          |||Gd z   f                   }Wt          |||Gdz   f                   }Xt          |||Gdz   f                   }Yt          |||Gdz   f                   }Zt          |||Gdz   f                   }[t          |||Gdz   f                   }\t          |||Gdz   f                   }]t          |||Gdz   f                   }^t          |||Gdz   f                   }_t          |||Gdz   f                   }`t          |||Gdz   f                   }at          |||Gdz   f                   }bt          |||Gdz   f                   }ct          |||Gdz   f                   }dt          |||Gdz   f                   }et          |||Gdz   f                   }ft          |||Gd z   f                   }g|H|Az  |Xz  }h|I|Az  |Yz  }i|J|Az  |Zz  }j|K|Az  |[z  }k|L|Az  |\z  }l|M|Az  |]z  }m|N|Az  |^z  }n|O|Az  |_z  }o|P|Az  |`z  }p|Q|Az  |az  }q|R|Az  |bz  }r|S|Az  |cz  }s|T|Az  |dz  }t|U|Az  |ez  }u|V|Az  |fz  }v|W|Az  |gz  }wtw          |h          }xty          |xtw          |i                    }xty          |xtw          |j                    }xty          |xtw          |k                    }xty          |xtw          |l                    }xty          |xtw          |m                    }xty          |xtw          |n                    }xty          |xtw          |o                    }xty          |xtw          |p                    }xty          |xtw          |q                    }xty          |xtw          |r                    }xty          |xtw          |s                    }xty          |xtw          |t                    }xty          |xtw          |u                    }xty          |xtw          |v                    }xty          |xtw          |w                    }x|B|xz  |z  }yt{          |yt          t|                              }yt          |y          }zt          |zt          d!          z            }{t          |z          |Bz  }|t          j
        | jC                  r|Ft          d          z  }}|Ct          d"          z  t          d          z  }~|Ct          d          z  }|Ft          d          z  }|Ct          d"          z  }| jE        | jF        z  }||z  || jF        z  z   |t          d          z  z   |~t          d          z  z   |}z   }|{||<   n|{||C|Ff<   |h||z  }|i||z  }|j||z  }|k||z  }|l||z  }|m||z  }|n||z  }|o||z  }|p||z  }|q||z  }|r||z  }|s||z  }|t||z  }|u||z  }|v||z  }|w||z  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|Gdz  }t          ||C|dz  z  |z             }t          ||           t          ||C|z  |Gz             }t          ||C|z  |Gz   t          d          z             }t          ||C|z  |Gz             }t          ||C|z  |Gz   t          d          z             }t          ||G          }t          ||Gt          d          z             }t          |          \  }}}}t          |          \  }}}}t          |          \  }}}}t          |          \  }}}}t          |          \  }}}}t          |          \  }}}}t          j
        |          rt          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |          }t          |          }t          |          }t          |¦          }t          |æ          }t          |Ħ          }t          |Ŧ          }t          |Ʀ          }t          ||Ȧ          }t          ||ʦ          }t          ||̦          }t          ||Φ          }t          ||Ц          }t          ||Ҧ          }t          ||Ԧ          }t          |զ          }||Az  }xt          ||A          \  }h}it          ||A          \  }j}kt          ||A          \  }l}mt          ||A          \  }n}ot          ||A          \  }p}qt          ||A          \  }r}st          ||A          \  }t}ut          ||A          \  }v}wnt          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |          }t          |          }t          |          }t          |¦          }t          |æ          }t          |Ħ          }t          |Ŧ          }t          |Ʀ          }t          ||Ȧ          }t          ||ʦ          }t          ||̦          }t          ||Φ          }t          ||Ц          }t          ||Ҧ          }t          ||Ԧ          }t          |զ          }||Az  }xt          ||A          \  }h}it          ||A          \  }j}kt          ||A          \  }l}mt          ||A          \  }n}ot          ||A          \  }p}qt          ||A          \  }r}st          ||A          \  }t}ut          ||A          \  }v}w|B|xz  |z  }yt{          |yt          t|                              }yt          |y          }zt          |zt          d!          z            }{t          |z          |Bz  }|t          j
        | jC                  r|Ft          d          z  }}|Ct          d"          z  t          d          z  }~|Ct          d          z  }|Ft          d          z  }|Ct          d"          z  }| jE        | jF        z  }||z  || jF        z  z   |t          d          z  z   |~t          d          z  z   |}z   }|{||<   n|{||C|Ff<   |h||z  }|i||z  }|j||z  }|k||z  }|l||z  }|m||z  }|n||z  }|o||z  }|p||z  }|q||z  }|r||z  }|s||z  }|t||z  }|u||z  }|v||z  }|w||z  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|Gdz  }t          ||C|dz  z  |z             }t          ||           dt          j
        |dk              rt          |||Gt          d          z   f                   }Ht          |||Gt          d          z   f                   }It          |||Gt          d          z   f                   }Jt          |||Gt          d          z   f                   }Kt          |||Gt          d          z   f                   }Lt          |||Gt          d          z   f                   }Mt          |||Gt          d          z   f                   }Nt          |||Gt          d          z   f                   }Ot          |||Gt          d          z   f                   }Pt          |||Gt          d          z   f                   }Qt          |||Gt          d          z   f                   }Rt          |||Gt          d          z   f                   }St          |||Gt          d          z   f                   }Tt          |||Gt          d          z   f                   }Ut          |||Gt          d          z   f                   }Vt          |||Gt          d           z   f                   }Wt          |||Gt          d          z   f                   }t          |||Gt          d#          z   f                   }t          |||Gt          d$          z   f                   }t          |||Gt          d%          z   f                   }t          |||Gt          d&          z   f                   }t          |||Gt          d'          z   f                   }t          |||Gt          d(          z   f                   }t          |||Gt          d)          z   f                   }t          |||Gt          d*          z   f                   }t          |||Gt          d+          z   f                   }t          |||Gt          d,          z   f                   }t          |||Gt          d-          z   f                   }t          |||Gt          d.          z   f                   }t          |||Gt          d/          z   f                   }t          |||Gt          d0          z   f                   }t          |||Gt          d1          z   f                   }t          |||Gt          d          z   f                   }Xt          |||Gt          d          z   f                   }Yt          |||Gt          d          z   f                   }Zt          |||Gt          d          z   f                   }[t          |||Gt          d          z   f                   }\t          |||Gt          d          z   f                   }]t          |||Gt          d          z   f                   }^t          |||Gt          d          z   f                   }_t          |||Gt          d          z   f                   }`t          |||Gt          d          z   f                   }at          |||Gt          d          z   f                   }bt          |||Gt          d          z   f                   }ct          |||Gt          d          z   f                   }dt          |||Gt          d          z   f                   }et          |||Gt          d          z   f                   }ft          |||Gt          d           z   f                   }gt          |||Gt          d          z   f                   }t          |||Gt          d#          z   f                   }t          |||Gt          d$          z   f                   }t          |||Gt          d%          z   f                   }t          |||Gt          d&          z   f                   }t          |||Gt          d'          z   f                   }t          |||Gt          d(          z   f                   }t          |||Gt          d)          z   f                   }t          |||Gt          d*          z   f                   }t          |||Gt          d+          z   f                   }t          |||Gt          d,          z   f                   }t          |||Gt          d-          z   f                   }t          |||Gt          d.          z   f                   }t          |||Gt          d/          z   f                   }t          |||Gt          d0          z   f                   }t          |||Gt          d1          z   f                   }|H|Az  |Xz  }h|I|Az  |Yz  }i|J|Az  |Zz  }j|K|Az  |[z  }k|L|Az  |\z  }l|M|Az  |]z  }m|N|Az  |^z  }n|O|Az  |_z  }o|P|Az  |`z  }p|Q|Az  |az  }q|R|Az  |bz  }r|S|Az  |cz  }s|T|Az  |dz  }t|U|Az  |ez  }u|V|Az  |fz  }v|W|Az  |gz  }w||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  } ||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }||Az  |z  }tw          |h          }xty          |xtw          |i                    }xty          |xtw          |j                    }xty          |xtw          |k                    }xty          |xtw          |l                    }xty          |xtw          |m                    }xty          |xtw          |n                    }xty          |xtw          |o                    }xty          |xtw          |p                    }xty          |xtw          |q                    }xty          |xtw          |r                    }xty          |xtw          |s                    }xty          |xtw          |t                    }xty          |xtw          |u                    }xty          |xtw          |v                    }xty          |xtw          |w                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                     }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xty          |xtw          |                    }xt          j
        | jX        d2k              rG|x|z  }yt          |y          }t          |t          d!          z            }t          |          }|nk|B|xz  |z  }yt{          |yt          t|                              }yt          |y          }zt          |zt          d!          z            }t          |z          |Bz  }|t          j
        | jC                  r|Ft          d          z  }}|Ct          d"          z  t          d          z  }~|Ct          d          z  }|Ft          d          z  }|Ct          d"          z  }| jE        | jF        z  }||z  || jF        z  z   |t          d          z  z   |~t          d          z  z   |}z   }|||<   n|||C|Ff<   |h||z  }|i||z  }|j||z  }|k||z  }|l||z  }|m||z  }|n||z  }|o||z  }|p||z  }|q||z  }|r||z  }|s||z  }|t||z  }|u||z  }|v||z  }|w||z  }|||z  }	|||z  }
|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }| ||z  }|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }t          |	|
||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|C|dz  z  |F|dz  z  z   }t          ||          } t          ||t          d          z             }!t          | |           t          |!|           "yt          ||C|z  |Gz             }"t          ||C|z  |Gz   t          d          z             }#t          ||C|z  |Gz   t          d          z             }$t          ||C|z  |Gz   t          d*          z             }%t          ||C|z  |Gz             }t          ||C|z  |Gz   t          d          z             }t          ||C|z  |Gz   t          d          z             }&t          ||C|z  |Gz   t          d*          z             }'t          ||G          }t          ||Gt          d          z             }t          ||Gt          d          z             }(t          ||Gt          d*          z             })t          |"          \  }}}}t          |#          \  }}}}t          |$          \  }*}+},}-t          |%          \  }.}/}0}1t          |          \  }}}}t          |          \  }}}}t          |&          \  }2}3}4}5t          |'          \  }6}7}8}9t          |          \  }}}}t          |          \  }}}}t          |(          \  }:};}<}=t          |)          \  }>}?}@}At          j
        |          rt          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |*|2          }Bt          |+|3          }Ct          |,|4          }Dt          |-|5          }Et          |.|6          }Ft          |/|7          }Gt          |0|8          }Ht          |1|9          }It          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |B|:          }Jt          |C|;          }Kt          |D|<          }Lt          |E|=          }Mt          |F|>          }Nt          |G|?          }Ot          |H|@          }Pt          |I|A          }Qt          |          }t          |          }t          |          }t          |¦          }t          |æ          }t          |Ħ          }t          |Ŧ          }t          |Ʀ          }t          |J          }Rt          |K          }St          |L          }Tt          |M          }Ut          |N          }Vt          |O          }Wt          |P          }Xt          |Q          }Yt          ||Ȧ          }t          ||ʦ          }t          ||̦          }t          ||Φ          }t          |R|S          }Zt          |T|U          }[t          |V|W          }\t          |X|Y          }]t          ||Ц          }t          ||Ҧ          }t          |Z|[          }^t          |\|]          }_t          ||Ԧ          }`t          |^|_          }at          |`|a          }t          |զ          }||Az  }xt          ||A          \  }h}it          ||A          \  }j}kt          ||A          \  }l}mt          ||A          \  }n}ot          ||A          \  }p}qt          ||A          \  }r}st          ||A          \  }t}ut          ||A          \  }v}wt          |J|A          \  }}t          |K|A          \  }}t          |L|A          \  }}t          |M|A          \  }}t          |N|A          \  }} t          |O|A          \  }}t          |P|A          \  }}t          |Q|A          \  }}nt          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |*|2          }Bt          |+|3          }Ct          |,|4          }Dt          |-|5          }Et          |.|6          }Ft          |/|7          }Gt          |0|8          }Ht          |1|9          }It          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |B|:          }Jt          |C|;          }Kt          |D|<          }Lt          |E|=          }Mt          |F|>          }Nt          |G|?          }Ot          |H|@          }Pt          |I|A          }Qt          |          }t          |          }t          |          }t          |¦          }t          |æ          }t          |Ħ          }t          |Ŧ          }t          |Ʀ          }t          |J          }Rt          |K          }St          |L          }Tt          |M          }Ut          |N          }Vt          |O          }Wt          |P          }Xt          |Q          }Yt          ||Ȧ          }t          ||ʦ          }t          ||̦          }t          ||Φ          }t          |R|S          }Zt          |T|U          }[t          |V|W          }\t          |X|Y          }]t          ||Ц          }t          ||Ҧ          }t          |Z|[          }^t          |\|]          }_t          ||Ԧ          }`t          |^|_          }at          |`|a          }t          |զ          }||Az  }xt          ||A          \  }h}it          ||A          \  }j}kt          ||A          \  }l}mt          ||A          \  }n}ot          ||A          \  }p}qt          ||A          \  }r}st          ||A          \  }t}ut          ||A          \  }v}wt          |J|A          \  }}t          |K|A          \  }}t          |L|A          \  }}t          |M|A          \  }}t          |N|A          \  }} t          |O|A          \  }}t          |P|A          \  }}t          |Q|A          \  }}t          j
        | jX        d2k              rG|x|z  }yt          |y          }t          |t          d!          z            }t          |          }|nk|B|xz  |z  }yt{          |yt          t|                              }yt          |y          }zt          |zt          d!          z            }t          |z          |Bz  }|t          j
        | jC                  r|Ft          d          z  }}|Ct          d"          z  t          d          z  }~|Ct          d          z  }|Ft          d          z  }|Ct          d"          z  }| jE        | jF        z  }||z  || jF        z  z   |t          d          z  z   |~t          d          z  z   |}z   }|||<   n|||C|Ff<   |h||z  }|i||z  }|j||z  }|k||z  }|l||z  }|m||z  }|n||z  }|o||z  }|p||z  }|q||z  }|r||z  }|s||z  }|t||z  }|u||z  }|v||z  }|w||z  }|||z  }	|||z  }
|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }| ||z  }|||z  }|||z  }|||z  }|||z  }|||z  }|||z  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }t          |	|
||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|C|dz  z  |F|dz  z  z   }t          ||          } t          ||t          d          z             }!t          | |           t          |!|           3d	S d	S )3a  Device kernel with cluster sync and Half2 SIMD.

        mGlobalScale contains the global scale value. The kernel reads it and
        computes 1/global_scale, which is multiplied with rstd to apply:
        y = h * rstd * w / global_scale = rmsnorm(h, w) / global_scale
        rd   r   r   rd   r   )orderr   )byte_alignmentr   N)	num_elems)r   r   )num_bits_per_copy)r   ))r   r   r   r   )predg        T)fastmathr   re            r  	   r                     r                                                r  )[r?   r   
thread_idx	block_idxr   r   r  r   r   r   r   r   r   rV   r   FLOAT4_E2M1_MAXutilsSmemAllocatorallocate_tensorelement_typemake_ordered_layoutr   allocate_arrayr   mbarrier_initmbarrier_init_fencecluster_arrive_relaxedcluster_waitmake_identity_tensor
local_tileprependrA   make_tensorr>   make_copy_atomnvgpucpasync	CopyG2SOpr  make_tiled_copy	get_slicepartition_Spartition_Dmake_fragment_liker   copycp_async_commit_groupcp_async_wait_groupautovec_copyr   tor   r   r   r   r   rsqrtr   rangerY   r^   ra   FLOAT8_E4M3_MAXrr   r
   r   rw   r   r   r  r  r   r	   rP   rG   rn   r   r   r   r   r   r   r   r   r   r   r   r   r  rz   r}   (b  r  r3  r4  r5  r6  r7  r8  r9  r:  rE  rF  tidx_bidxr   r   r  r   r   	cluster_yr   r   r   lane_in_rowrow_in_blockfp4_max_rcpr?  sXsRsWsHr   r7   idXgXgRcXmW_expanded_layoutmW_2dgWcopy_atom_load_asynctiled_copy_load
thr_copy_X
thr_copy_RtXgXtXsXtRgRtRsRr   
thr_copy_WtWgWtWsWtHsHtXrXtRrRr   	row_coordrow_in_boundsx_valsr_valsh_valsh_sqh_elemsum_sqmean_sqrstdglobal_scale_valactual_row_idxnum_sf_per_threadsf_itersf_idxblock_startsh0sh1sh2sh3sh4sh5sh6sh7sh8sh9sh10sh11sh12sh13sh14sh15sw0sw1sw2sw3sw4sw5sw6sw7sw8sw9sw10sw11sw12sw13sw14sw15y0y1y2y3y4y5y6y7y8y9y10y11y12y13y14y15max_absscale_floatscale_fp8_u32	scale_fp8	inv_scaleinner_k_idxinner_m_idxouter_m_idx
k_tile_idx
m_tile_idxm_tile_strideswizzled_offsetq0q1q2q3q4q5q6q7q8q9q10q11q12q13q14q15	packed_lo	packed_hipacked64
out_offsetout_ptrh_ptr0h_ptr1r_ptr0r_ptr1w_ptr0w_ptr1x0x1x2x3x4x5x6x7r0r1r2r3r4r5r6r7w0w1w2w3w4w5w6w7h0h1r   h3h4h5h6h7hw0hw1hw2hw3hw4hw5hw6hw7abs0abs1abs2abs3abs4abs5abs6abs7max01max23max45max67max0123max4567max_hwmax_xwsh16sh17sh18sh19sh20sh21sh22sh23sh24sh25sh26sh27sh28sh29sh30sh31sw16sw17sw18sw19sw20sw21sw22sw23sw24sw25sw26sw27sw28sw29sw30sw31y16y17y18y19y20y21y22y23y24y25y26y27y28y29y30y31scale_ue8m0scale_u8q16q17q18q19q20q21q22q23q24q25q26q27q28q29q30q31packed_lo_0packed_hi_0
packed64_0packed_lo_1packed_hi_1
packed64_1
fp4_offset	fp4_ptr_0	fp4_ptr_1x_ptr0x_ptr1x_ptr2x_ptr3r_ptr2r_ptr3w_ptr2w_ptr3x8x9x10x11x12x13x14x15r8r9r10r11r12r13r14r15w8w9w10w11w12w13w14w15h8h9h10h11h12h13h14h15hw8hw9hw10hw11hw12hw13hw14hw15abs8abs9abs10abs11abs12abs13abs14abs15max89maxabmaxcdmaxefmax89abmaxcdefmax_lomax_hisb                                                                                                                                                                                                                                                                                                                                                                    r    rA  zAddRMSNormFP4QuantKernel.kernel  s:   ( Y))++
aY((**
aF_
 $ :,N	i!m,, 	.	++--a0II*1--I#/!,Q/Or1155!!_,.$W_%=%=>> }**,,!!O$XV<<< " 
 
 !!O$XV<<< " 
 
 i1n-- 
	%%(@@@! &  B
 %%(@@@! &  B i1n-- 	?#33 .-!@AA   4    
 HH#33 .=)2L!MNN   4    
 **5A*>>H i!m,, 	%qyy	''!444I))+++I,,...I""$$$ '11_RD)+<==_RD)+<==_S(T9,=>>i1n-- 	B!%	4+Xa[N4HHH" " $R[2DEEE1i.AAB  $2J((**O' 
  
  
 . )X
 
 %..t44
$..t44
%%b))%%b))%%b))%%b))%%b))i1n-- 	.(22488J))"--D))"--D))"--D&t,,&t,, 4q)))&	!!q(  	CI*D$TBBBBI*D$TBBBBi1n-- 	CI*D$TBBBB	'')))	%%a((( 	$%%%$%%%((((&i1n-- 	YYr//FJJv CLL
 
 1*yw}t<< (?i!m,, 	 I,,...I""$$$$I.= A%7!; !! !!233 qA qA$w'@@111"(:"5K)**:;; kA"-i1n== y=")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC")"\;?-J*K"L"LC#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND#*2lK"<L.L+M#N#ND!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C&.rllG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG +;W*D{*RK*2 +W_-E-E+ +K -<K,H,HM(-mfTll.J(K(KI !8 F F"2!3 &
  '1$2FGG G.4uQxx.?/=c

/JuUWyy.X.<uRyy.H-3uQxx-?
-;uSzz-I
040@4CU0U$.$>&043E&E%F&1E"II&=%> '2E!HH&<%= '2	%2 !0 7@? 3 3=F>6#9 :!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C(6r2r2r2rSU(V(VI(6 "BS#sC) )I )/y(9(9VBZZ(G6 )L L (H *5)9J&6 "Na1f$=
$J' 'G *'8<<<< &6 "NQ$6$D& &F &6 "NQ$6$DuQxx$O& &F &6 "NQ$6$D& &F &6 "NQ$6$DuQxx$O& &F &6b+%F%FF%5b+a:P%Q%QF-=f-E-ENBB-=f-E-ENBB-=f-E-ENBB-=f-E-ENBB-=f-E-ENBB-=f-E-ENBB&1':: ]O%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7',Szz',Szz',Szz',Szz',Szz',Szz',Szz',Szz(-dD(9(9(-dD(9(9(-dD(9(9(-dD(9(9*/u*=*=*/u*=*=).w)@)@)4V)<)<*04-)?T)J)JB)?T)J)JB)?T)J)JB)?T)J)JB)?T)J)JB+A#t+L+LS+A#t+L+LS+A#t+L+LSS%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9(5dD(A(A(5dD(A(A(5dD(A(A(5dD(A(A*7u*E*E*7u*E*E)6w)H)H)<V)D)D*04-)A#t)L)LB)A#t)L)LB)A#t)L)LB)A#t)L)LB)A#t)L)LB+CC+N+NS+CC+N+NS+CC+N+NS +;W*D{*RK*2 +W_-E-E+ +K -<K,H,HM(-mfTll.J(K(KI !8 F F"2!3 &
  '1$2FGG G.4uQxx.?/=c

/JuUWyy.X.<uRyy.H-3uQxx-?
-;uSzz-I
040@4CU0U$.$>&043E&E%F&1E"II&=%> '2E!HH&<%= '2	%2 !0 7@? 3 3=F>6#9 :!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C(6r2r2r2rSU(V(VI(6 "BS#sC) )I )/y(9(9VBZZ(G6 )L L (H *5)9J&6 "Na1f$=
$J' 'G *'8<<<< #-i1n== mA")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC")"\;q;Q-Q*R"S"SC#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD#*2lK%PR))<S.S+T#U#UD "%tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B!$tc!1B"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C"&+"4C '/rllG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&E&EG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG&.w&F&FG
  '1$2Cw2NOO ".5.C.>{.K.K+0vd||1K+L+L,A+,N,N		 /?.H;.V.6$/1I1I/" /" 1@0L0L+01M+N+N %<M$J$J&6%7 !*
  '1$2FGG F.4uQxx.?/=c

/JuUWyy.X.<uRyy.H-3uQxx-?
-;uSzz-I
040@4CU0U$.$>&043E&E%F&1E"II&=%> '2E!HH&<%= '2	%2 !0 7?? 3 3=E>6#9 : "$iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C*8RRRQSUW*X*XK*8 "BS#sC+ +K +1*=*=*Kv +P P *J +9 #S#sCc3+ +K +9 #S#sCc3+ +K +1*=*=*Kv +P P *J *816)BV *aF *J )9Z(H(HI(8Z%PQ((=R(S(SI))Z@@@))Z@@@@
 &6 "NQ$6$D& &F &6 "NQ$6$DuQxx$O& &F &6 "NQ$6$DuRyy$P& &F &6 "NQ$6$DuRyy$P& &F &6 "NQ$6$D& &F &6 "NQ$6$DuQxx$O& &F &6 "NQ$6$DuRyy$P& &F &6 "NQ$6$DuRyy$P& &F &6b+%F%FF%5b+a:P%Q%QF%5b+b		:Q%R%RF%5b+b		:Q%R%RF-=f-E-ENBB-=f-E-ENBB/?/G/G,BC1A&1I1I.Cc3-=f-E-ENBB-=f-E-ENBB/?/G/G,BC1A&1I1I.Cc3-=f-E-ENBB-=f-E-ENBB/?/G/G,BC1A&1I1I.Cc3&1':: mP%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]%*2r]]&+Coo&+Coo&+Coo&+Coo&+Coo&+Coo&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7&/B&7&7'0c':':'0c':':'0c':':'0c':':'0c':':'0c':':',Szz',Szz',Szz',Szz',Szz',Szz',Szz',Szz',Szz',Szz(-d(-d(-d(-d(-d(-d(-dD(9(9(-dD(9(9(-dD(9(9(-dD(9(9(-dD(9(9(-eU(;(;(-eU(;(;(-eU(;(;*/u*=*=*/u*=*=*/u*=*=*/u*=*=).w)@)@).w)@)@).vv)>)>)4V)<)<*04-)?T)J)JB)?T)J)JB)?T)J)JB)?T)J)JB)?T)J)JB+A#t+L+LS+A#t+L+LS+A#t+L+LS+A#t+L+LS+A#t+L+LS+A$+M+MS+A$+M+MS+A$+M+MS+A$+M+MS+A$+M+MS+A$+M+MSS%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8%0R%8%8&1#s&;&;&1#s&;&;&1#s&;&;&1#s&;&;&1#s&;&;&1#s&;&;&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9&1"b&9&9'23'<'<'23'<'<'23'<'<'23'<'<'23'<'<'23'<'<'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9'4S'9'9(5d(;(;(5d(;(;(5d(;(;(5d(;(;(5d(;(;(5d(;(;(5dD(A(A(5dD(A(A(5dD(A(A(5dD(A(A(5dD(A(A(5eU(C(C(5eU(C(C(5eU(C(C*7u*E*E*7u*E*E*7u*E*E*7u*E*E)6w)H)H)6w)H)H)6vv)F)F)<V)D)D*04-)A#t)L)LB)A#t)L)LB)A#t)L)LB)A#t)L)LB)A#t)L)LB+CC+N+NS+CC+N+NS+CC+N+NS+CC+N+NS+CC+N+NS+CD$+O+OS+CD$+O+OS+CD$+O+OS+CD$+O+OS+CD$+O+OS+CD$+O+OS
  '1$2Cw2NOO ".5.C.>{.K.K+0vd||1K+L+L,A+,N,N		 /?.H;.V.6$/1I1I/" /" 1@0L0L+01M+N+N %<M$J$J&6%7 !*
  '1$2FGG F.4uQxx.?/=c

/JuUWyy.X.<uRyy.H-3uQxx-?
-;uSzz-I
040@4CU0U$.$>&043E&E%F&1E"II&=%> '2E!HH&<%= '2	%2 !0 7?? 3 3=E>6#9 :!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB!#iB"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C"%	/C*8RRRQSUW*X*XK*8 "BS#sC+ +K +1*=*=*Kv +P P *J +9 #S#sCc3+ +K +9 #S#sCc3+ +K +1*=*=*Kv +P P *J *816)BV *aF *J )9Z(H(HI(8Z%PQ((=R(S(SI))Z@@@))Z@@@m 
qA qAr"   )NN)__name__
__module____qualname____doc__r   Numericr   boolstrr  staticmethodr	  r  r  r  tupler*  r2  r?   jitTensorr   r   rG  rA  LayoutShaper$  r"   r    r   r   9  s         "&#'1% 1%1% 1% 	1%
 1% 1% $J1% Dj1% 1% 1% 1%f c '/ s s    \2 C C    \ 2 2 2 2 2 \2 W W W W W W W \W6   	
 
   \"
S 
 
 
 
8 
X*
K*
 K*
 K	*

 K*
 K*
 k*
 *
 *
 *
 *
 X*
X 
[uAKuA KuA K	uA
 KuA KuA kuA uA uA ;uA *uA uA uA [uA uA uAr"   r   hidden_sizer   r   r  r  is_sf_swizzled_layoutc                    |rt           j        nt           j        }t          || ||||          }t	          j                    }t          j                            ||| fdd          }	t          j                            ||| fdd          }
t          j                            || fd          }t          j                            t           j        || dz  fdd          }rAt	          j                    }t          j                            t           j        |fd          }n2t          j                            t           j        || |z  fdd          }t          j        	                    d          }t          j                            t           j
        d	d
          }t	          j        ||	|
||||t          d          t          d          |d          dt          j        dt          j        dt          j        dt          j        dt          j        dt          j        dt          dt           ddffd}|S )z
    Get a compiled kernel closure that takes torch.Tensor directly.

    Uses TVM-FFI for efficient tensor passing without manual pointer construction.
    )r   r   r   r   r   r  r  rI  r   )stride_orderassumed_align)r  r   T)use_tvm_ffi_env_stream)rd   r   rd   ư>z--enable-tvm-ffi)optionsr<   rwysglobal_scaler9  r:  r   Nc                     r|                                 n|                                }|                    t          j                  }	 
| |||	||t          |          t          |                     dS )z;Runtime API that passes torch tensors directly via TVM-FFI.N)flatten
contiguousviewr   uint8r   r   )r<   r  r  r  r  r  r9  r:  s_tensory_uint8compiled_kernelr  s             r    
tensor_apiz(_get_compiled_kernel.<locals>.tensor_api  sx     #8K199;;;Q\\^^&&%%!HHCLL		
 		
 		
 		
 		
r"   )r   Float16BFloat16r   r?   sym_intruntimemake_fake_compact_tensorr
   make_fake_streamr   compiler   r   r  r   float)r  r   r   r  r  r  cutlass_dtype
kernel_objsym_mx_faker_fakew_fakey_fakesym_swizzled_sizes_fakestream_fakeglobal_scale_faker  r  s        `            @r    _get_compiled_kernelr
  X  s    (/DGOOG4DM)
-!  J LNNE \22{+&PS 3  F \22{+&PS 3  F \22~S 3  F \22{a/0vUX 3  F
  
 !LNN66M-/s 7 
 

 66MK:-.	 7 
 
 ,//t/LLK ==Q >  
 la"  O
<
<
 <
 <	

 <
 l
 
 
 

 
 
 
 
 
 
6 r"   r  Finputresidualweighty_fp4block_scaler  r:  c
           
         |                                  dk    }
|
rb| j        \  }}}|                     ||z  |                                          }|                    ||z  |                                          }n| }|}|j        \  }}| j        }||z  dk    s
J d            |dk    s
J d            |dv s
J d            |t
          j        k    }|r|n	|dk    rd	nd
}t          | j                  }|Z|
r-t          j	        |||dz  ft
          j
        | j                  }n+t          j	        ||dz  ft
          j
        | j                  }||d	k    rt
          j        nt
          j        }||z  }|	r8|dz   dz  }|dz   dz  }d}||z  |z  }t          j	        |f|| j                  }n@|
r t          j	        |||f|| j                  }nt          j	        ||f|| j                  }|
r7|                    ||z  d          }|	s|                    ||z  d          n|}n|}|}|&t          j        dt
          j        | j                  }t          ||||||	          } ||                                |                                |                                ||                    t
          j                  |                                ||           ||fS )a`  
    Fused Add + RMS normalization + FP4 quantization using CuTe-DSL.

    Computes: ``h = input + residual``, then ``y = RMSNorm(h) * weight``,
    optionally applies global scaling (``y = y / global_scale``),
    and finally quantizes ``y`` to FP4.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor, shape ``(batch_size, hidden_size)`` or ``(batch_size, seq_len, hidden_size)``.
        Must be ``torch.float16`` or ``torch.bfloat16``.
    residual : torch.Tensor
        Residual tensor to add to input. Must have the same shape and dtype as ``input``.
    weight : torch.Tensor
        Weight tensor for RMSNorm, shape ``(hidden_size,)``.
        Must have the same dtype as input.
    y_fp4 : torch.Tensor, optional
        Output tensor for quantized values in FP4_E2M1 format with dtype
        ``torch.float4_e2m1fn_x2``.
        Shape must be ``(batch_size, hidden_size // 2)`` or matching 3D input.
        If ``None``, will be allocated automatically.
    block_scale : torch.Tensor, optional
        Output tensor for per-block scale factors.

        - If ``is_sf_swizzled_layout=False`` (default): row-major layout with shape
          ``(batch_size, hidden_size // block_size)`` or matching 3D input.
        - If ``is_sf_swizzled_layout=True``: swizzled layout for efficient tensor core
          access, with shape ``(batch_size * hidden_size // block_size,)`` flattened.
          The swizzle pattern uses 128x4 tiles where scales are arranged as:
          ``[m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)]``.

        Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8``
        for UE8M0 format. If ``None``, will be allocated automatically.
    global_scale : torch.Tensor, optional
        Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
        If provided, the RMSNorm output is divided by this value before quantization:
        ``y = rmsnorm(h, w) / global_scale`` where ``h = input + residual``.
        This is used for NVFP4 format where a pre-computed global scale lifts
        per-block scales into optimal dynamic range.
        If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
    eps : float
        Epsilon for numerical stability in RMSNorm. Default is ``1e-6``.
    block_size : int
        Number of elements per quantization block. Default is ``16``.

        - ``16``: NVFP4 format with E4M3 scale factors
        - ``32``: MXFP4 format with UE8M0 scale factors
    scale_format : str, optional
        Scale factor format: ``"e4m3"`` or ``"ue8m0"``.
        If ``None``, auto-selects based on ``block_size``:
        ``"e4m3"`` for block_size=16, ``"ue8m0"`` for block_size=32.
    is_sf_swizzled_layout : bool
        If ``True``, output scale factors in swizzled layout optimized for
        tensor core GEMM operations. The swizzle uses 128x4 tiles with the pattern:
        ``[m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k]``
        where ``outer_m = row % 32``, ``inner_m = (row % 128) // 32``, etc.
        Default is ``False`` (row-major layout).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        A tuple of ``(y_fp4, block_scale)``:

        - ``y_fp4``: Quantized FP4 values packed as uint8.
        - ``block_scale``: Per-block scale factors.

    Notes
    -----
    - Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.
    - For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).
    - For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).
    - FP4 E2M1 format has a max representable value of 6.0.
    re   r   z+hidden_size must be divisible by block_sizer!  zhidden_size must be >= 64r  zblock_size must be 16 or 32r   r  r  Nr   )r   r      r   r   r  rd   )dimr   r  r  r   r   float16r!   r   emptyfloat4_e2m1fn_x2r  float8_e4m3fnonesfloat32r
  )r  r  r  r  r  r  r:  r   r  r  is_3dBSr   input_2dresidual_2d
batch_sizer  r   r   actual_scale_formatr  scale_dtyper  num_m_tilesr  r  swizzled_sizey_fp4_2dblock_scale_2dr  s                                  r    add_rmsnorm_fp4quantr&    s   n IIKK1E +1a::a!eQ''2244mmAE1--88::&nJKE#q(((*W((("9!!!#@!!!u}$G$SjB6F6F77F   --J } 	KA{a'(,|  EE K[A-.,|  E  /'99EKKu?R 	 !,z 9  	%+3K014:KM'+5EM+ EL  KK  #k01% <   $k!67% <    %::a!eR((/DUKQUB'''+ 	 $ z!5=NNN% J J  EK((!!	 	 	 +r"   )r   r&  r!   )N)r   )NNNr  r   NF)Nr  	functoolsr   r   typingr   r   r   cutlass.cuter?   r   r   r   r   r   r	   r
   cutlass.cutlass_dslr   r   cutlass._mlir.dialectsr   api_loggingr   rk  r  SF_VEC_SIZEr  	lru_cacher   r   r  r!   Pointerr5   r;   r  rC   rG   rP   rV   rY   r^   ra   rn   rr   rw   rz   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r  	Constexprr   r   r   r   r   r   r   r   cacher  r
  r  r&  __all__r$  r"   r    <module>r3     s   0       " " " " " " " "         @ @ @ @ @ @ @ @ @ @ @ @ @ @ @ @ . . . . . . . . ' ' ' ' ' ' ( ( ( ( ( ( 	 R   * *3-3d: *c * * * ! *. DHT  l6;
   $  	  	l l $	 
   6 /3 F F FDK F F F F F 9=$   E &      @D   T[ % RW     &*t   g      "&4   ' G     ,0T    G g     ,0T    G g      T: : ::
6666)*: : : :: )-$    6    . 48T   V g    D .2t   g     < 48T   V g    > +/D    F V     -1d   6 f v      T   V f     '+   V       "&4   6 G    0 $(T   V f     /3   V       *.4   6 G    6 '+$ $ $$$
7G$ $ $ $@ (, $  $  $	 $ $
7G $  $  $  $F '+   V       -1d   6 f v      	* * *** 	* 		*
 	* 	* 	* 	* * * * *d   1# 6    
 -	-- k- 	-
 - - - 
-. ,-	,-,- k,- l	,-
  %,- ,- ,- ,- ,- 
,-^ ~ &s+ k	  %    
@ dk # $+    
4WA WA WA WA WA WA WA WA~( ooo o 	o
 o  o o o o od 
 "&'+(,#"'z z<zlz Lz <$	z
 $z ,%z 
z z *z  z 5<%&z z z zz  r"   