
    )`i,                        d Z ddlZddlZddlZddlmZmZ ddlZddlm	Z	 ddl
Z
ddlmZmZmZmZmZmZ ddlmZmZ ddlmZ ddlmZ d	Zd
ZdZdZ ej        d          dcdee
j        z  ez  dz  defd            Z edddde	j!        dedefd            Z"eddddede	j!        de	j!        deddf
d            Z#edddde	j$        de	j!        fd            Z%eddddedeeeeef         fd            Z&eddddedefd            Z'edddde	j$        dedefd             Z(edddd!edefd"            Z)edddd!ed#edefd$            Z*edddd!ed#edefd%            Z+edddd!ed#edefd&            Z,eddddedefd'            Z-edddd!ed#edefd(            Z.eddddedefd)            Z/eddddedefd*            Z0edddd!ed#edefd+            Z1eddddedefd,            Z2edddd-ed.edeeef         fd/            Z3edddd0ed.edeeef         fd1            Z4edddd!edefd2            Z5edddd3edefd4            Z6edddd5edefd6            Z7edddd7edefd8            Z8edddd9ed:ed;ed<ed=ed>ed?ed@edefdA            Z9e	j:        dddCej;        e         fdD            Z<e	j:        dedEedFe	j$        dGedef
dH            Z=e	j:        dedEedFe	j$        de	j!        dIej;        e         dGedefdJ            Z>e	j:        de	j?        dEe	j@        dKej;        e         dFe	j$        dIej;        e         dGefdL            ZAe	j:        dMe	j$        dNede	j$        fdO            ZB G dP dQ          ZCejD        dRedSedTeEdUedVedWeEdefdX            ZFe	 	 	 	 	 	 	 ded[e
j$        d\e
j$        d]e
j$        dz  d^e
j$        dz  d_e
j$        dz  d`eGdSedVedz  dWeEdee
j$        e
j$        f         fda            ZHg dbZIdS )fa  
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Fused RMSNorm + FP4 Quantization using CuTe-DSL
================================================

High-performance fused kernel for RMS normalization followed by FP4 quantization.
This is an alternative backend to cuDNN, using CuTe-DSL for maximum flexibility
and performance on SM100+ architectures.

Supports both NVFP4 and MXFP4 quantization formats.
    N)CallableTuple)Float32Int32Int64Uint32Uint64Uint8)Tdsl_user_op)llvm   )flashinfer_apig      @g      |@      )maxsizedevicereturnc                     t           j                                        sdS | t           j                                        } t           j                            |           }|j        dz  |j        z   S )a  Get the SM version of a CUDA device.

    Args:
        device: CUDA device to query. Can be an int (device index), torch.device,
            device string (e.g., 'cuda:0'), or None to use current device.

    Returns:
        SM version as an integer (e.g., 100 for SM100).
    P   N
   )torchcudais_availablecurrent_deviceget_device_propertiesmajorminor)r   propss     x/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/flashinfer/cute_dsl/rmsnorm_fp4quant.pyget_sm_versionr!   8   s`     :""$$ r~**,,J,,V44E;ek))    locipsmem_ptrpeer_cta_rank_in_clusterc                   |                      ||                                          }t          t          j        t          j                    ||                                gddddt          j        j                            S )z?Map smem pointer to address at another CTA rank in the cluster.r#   z$mapa.shared::cluster.u32 $0, $1, $2;=r,r,rFhas_side_effectsis_align_stackasm_dialect)	tointir_valuer   r   
inline_asmr   i32
AsmDialectAD_ATT)r&   r'   r$   r%   smem_ptr_i32s        r    set_block_rankr5   P   s}    
 >>cb>11::<<LEGG3<<>>?2" .	
 	
 	

 
 
r"   valmbar_ptrc          	         t          ||||                                          }t          ||||                                          }t          j        d||                     ||          |gddddt          j        j                   dS )zDStore Float32 value to shared memory on a remote CTA in the cluster.r#   NzIst.async.shared::cluster.mbarrier::complete_tx::bytes.f32 [$0], $1, [$2];zr,f,rTFr*   )r5   r/   r   r0   r2   r3   )r6   r&   r7   r'   r$   r%   remote_smem_ptr_i32remote_mbar_ptr_i32s           r    store_shared_remoter;   c   s     )*  hjj  )*  hjj  	O	cllsrl::<OPSO*     r"   xc                L    | j         t          j        || j        ||          z   S )z/Get pointer to element at coordinate in tensor.r#   )iteratorcutecrd2idxlayout)r<   coordr$   r%   s       r    elem_pointerrC      s&     :UAH#"EEEEEr"   base_ptrc                6   t          j        t           j                            t	          j                    t	          j                    t	          j                    t	          j                    g          t          |                               ||          gddddt           j        j	        ||	  	        }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |dg||          }t          j
        t	          j                    |d	g||          }t          |          t          |          t          |          t          |          fS )
z.Load 128 bits (4 x uint32) from global memory.r#   z(ld.global.v4.u32 {$0, $1, $2, $3}, [$4];z=r,=r,=r,=r,lFr+   r,   r-   r$   r%   r      r      )r   r0   
StructTypeget_literalr   r1   r   r/   r2   r3   extractvaluer   )rD   r$   r%   resultv0v1v2v3s           r    ld_global_v4_u32rQ      s<   
 _##QUWWaeggquww$HII	x	!	!cb	!	1	122O*
 
 
F 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B		1577FQCSR	@	@	@B		1577FQCSR	@	@	@B"::vbzz6"::vbzz99r"   valuec          	          t          j        dt          |                               ||          t	          |                              ||          gddddt           j        j                   dS )zStore 64 bits to global memory.Nr#   zst.global.u64 [$0], $1;zl,lTFr*   )r   r0   r   r/   r	   r2   r3   )rD   rR   r$   r%   s       r    st_global_u64rT      s}     	O(OO$$$445MM""sr"22	
 	"O*     r"   tensoroffsetc                    | j         t          |          z   }t          j        t	          j                    |j        ||          }t          |          S )z2Get the memory address of tensor[offset] as Int64.r#   )r>   r   r   ptrtointr   i64llvm_ptrr   )rU   rV   r$   r%   elem_ptrptr_ints         r    get_ptr_as_int64r]      sC     v.HmAEGGX%6CBGGGG>>r"   ac                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z-Fast reciprocal using PTX rcp.approx.ftz.f32.r#   zrcp.approx.ftz.f32 $0, $1;z=f,fFr*   r   r   r0   r   f32r/   r2   r3   r^   r$   r%   s      r    rcp_approx_ftzrc      sf     EGGQZZ  SR 001(" .	
 	
 	

 
 
r"   bc                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )zJCompute min of two float32 values using PTX min.f32 (branchless clamping).r#   zmin.f32 $0, $1, $2;z=f,f,fFr*   r`   r^   rd   r$   r%   s       r    fmin_f32rg      s     EGGQZZ  SR 00'!**2E2E#RT2E2U2UV!" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z;Multiply two Half2 values element-wise: (a.x*b.x, a.y*b.y).r#   zmul.f16x2 $0, $1, $2;r)   Fr*   r   r   r0   r   r1   r/   r2   r3   rf   s       r    	half2_mulrj           EGGAYYCB//1C1CPR1C1S1ST#" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z=Multiply two BFloat2 values element-wise: (a.x*b.x, a.y*b.y).r#   zmul.bf16x2 $0, $1, $2;r)   Fr*   ri   rf   s       r    bfloat2_mulrm           EGGAYYCB//1C1CPR1C1S1ST$" .	
 	
 	

 
 
r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z<Half2 absolute value - clears sign bits of both fp16 values.r#   and.b32 $0, $1, 0x7FFF7FFF;=r,rFr*   ri   r<   r$   r%   s      r    habs2rs     f     EGGAYYCB//0)" .	
 	
 	

 
 
r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z-Half2 max - element-wise max of 2 fp16 pairs.r#   zmax.f16x2 $0, $1, $2;r)   Fr*   ri   rf   s       r    hmax2rv     rk   r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z1Extract max of 2 fp16 values in half2 as float32.r#   z
            {
                .reg .b16 h0, h1;
                .reg .f32 f0, f1;
                mov.b32 {h0, h1}, $1;
                cvt.f32.f16 f0, h0;
                cvt.f32.f16 f1, h1;
                max.f32 $0, f0, f1;
            }
            =f,rFr*   	r   r   r0   r   ra   r   r/   r2   r3   rr   s      r    hmax_to_f32rz   &  sj     EGGAYYCB//0	 " .!	
 	
 	
  r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )zABFloat16x2 absolute value - clears sign bits of both bf16 values.r#   rp   rq   Fr*   ri   rr   s      r    bfloat2_habs2r|   ?  rt   r"   c                   t          t          j        t          j                    t          |                               ||          t          |                              ||          gddddt          j        j                            S )z2BFloat16x2 max - element-wise max of 2 bf16 pairs.r#   zmax.bf16x2 $0, $1, $2;r)   Fr*   ri   rf   s       r    bfloat2_hmax2r~   O  rn   r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z3Extract max of 2 bf16 values in bfloat2 as float32.r#   ae  
            {
                .reg .b32 lo, hi;
                .reg .f32 f0, f1;
                and.b32 lo, $1, 0xFFFF;
                shr.b32 hi, $1, 16;
                shl.b32 lo, lo, 16;
                shl.b32 hi, hi, 16;
                mov.b32 f0, lo;
                mov.b32 f1, hi;
                max.f32 $0, f0, f1;
            }
            rx   Fr*   ry   rr   s      r    bfloat2_hmax_to_f32r   _  sj     EGGAYYCB//0 " .'	
 	
 	
  r"   h2scalec                P   t          j        t           j                            t	          j                    t	          j                    g          t          |                               ||          t          |                              ||          gddddt           j	        j
        ||	  	        }t          j        t	          j                    |dg||          }t          j        t	          j                    |dg||          }t          |          t          |          fS )z.Convert half2 to float2 AND multiply by scale.r#   z
        {
            .reg .b16 h0, h1;
            .reg .f32 f0, f1;
            mov.b32 {h0, h1}, $2;
            cvt.f32.f16 f0, h0;
            cvt.f32.f16 f1, h1;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        	=f,=f,r,fFrF   r   rG   r   r0   rI   rJ   r   ra   r   r/   r   r2   r3   rK   )r   r   r$   r%   rL   f0f1s          r    half2_to_float2_scaledr     s    
 _##QUWWaegg$677					,	,genn.E.E#RT.E.U.UV
	 	O*'  F, 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B2;;##r"   bf2c                P   t          j        t           j                            t	          j                    t	          j                    g          t          |                               ||          t          |                              ||          gddddt           j	        j
        ||	  	        }t          j        t	          j                    |dg||          }t          j        t	          j                    |dg||          }t          |          t          |          fS )z3Convert bfloat16x2 to float2 AND multiply by scale.r#   aU  
        {
            .reg .b32 lo, hi;
            .reg .f32 f0, f1;
            and.b32 lo, $2, 0xFFFF;
            shr.b32 hi, $2, 16;
            shl.b32 lo, lo, 16;
            shl.b32 hi, hi, 16;
            mov.b32 f0, lo;
            mov.b32 f1, hi;
            mul.f32 $0, f0, $3;
            mul.f32 $1, f1, $3;
        }
        r   FrF   r   rG   r   )r   r   r$   r%   rL   r   r   s          r    bfloat2_to_float2_scaledr     s    
 _##QUWWaegg$677			#"		-	-wu~~/F/F3SU/F/V/VW	 	O*-  F2 
	1577FQCSR	@	@	@B		1577FQCSR	@	@	@B2;;##r"   c                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )zAConvert float32 to E4M3 using native cvt.rn.satfinite.e4m3x2.f32.r#   a  
            {
                .reg .b16 fp8_pair;
                .reg .f32 zero;
                mov.f32 zero, 0f00000000;
                cvt.rn.satfinite.e4m3x2.f32 fp8_pair, zero, $1;
                cvt.u32.u16 $0, fp8_pair;
            }
            =r,fFr*   	r   r   r0   r   r1   r   r/   r2   r3   rb   s      r    cvt_f32_to_e4m3r     sj     EGGQZZ  SR 001 " .	
 	
 	
  r"   fp8_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z3Convert FP8 E4M3 to float32 AND compute reciprocal.r#   a  
            {
                .reg .pred p_zero;
                .reg .u32 exp_u, mant_u;
                .reg .s32 exp_s;
                .reg .f32 exp_f, mant_f, fp8_float, result;

                setp.eq.u32 p_zero, $1, 0;
                and.b32 mant_u, $1, 7;
                shr.b32 exp_u, $1, 3;
                and.b32 exp_u, exp_u, 15;
                sub.s32 exp_s, exp_u, 7;
                cvt.rn.f32.s32 exp_f, exp_s;
                ex2.approx.f32 exp_f, exp_f;
                cvt.rn.f32.u32 mant_f, mant_u;
                fma.rn.f32 mant_f, mant_f, 0f3E000000, 0f3F800000;
                mul.f32 fp8_float, exp_f, mant_f;
                rcp.approx.ftz.f32 result, fp8_float;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            rx   Fr*   ry   )r   r$   r%   s      r    fp8_e4m3_to_f32_and_rcpr     sj     EGGG__%%#"%556* " .7	
 	
 	
  r"   max_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )aS  
    Convert float32 max value to UE8M0 scale factor.

    UE8M0 is unsigned 8-bit exponent-only format:
    - value = 2^(ue8m0 - 127)
    - ue8m0 = ceil(log2(max_val)) + 127

    Uses lg2.approx.f32 for fast log2 approximation.
    Uses cvt.rpi (round towards positive infinity, i.e., ceiling).
    Returns value clamped to [0, 255].
    r#   a  
            {
                .reg .pred p_zero, p_neg, p_ovf;
                .reg .f32 log2_val;
                .reg .s32 exp_int, result;

                // Check for zero/negative
                setp.le.f32 p_zero, $1, 0f00000000;

                // Compute ceil(log2(max_val)) using cvt.rpi (round towards +inf)
                lg2.approx.f32 log2_val, $1;
                cvt.rpi.s32.f32 exp_int, log2_val;

                // Add bias and clamp to [0, 255]
                add.s32 result, exp_int, 127;
                setp.lt.s32 p_neg, result, 0;
                setp.gt.s32 p_ovf, result, 255;
                selp.s32 result, 0, result, p_neg;
                selp.s32 result, 255, result, p_ovf;
                selp.s32 $0, 0, result, p_zero;
            }
            r   Fr*   r   )r   r$   r%   s      r    cvt_f32_to_ue8m0r     sl     EGGW&&32&667, " .9	
 	
 	
  r"   	ue8m0_valc                    t          t          j        t          j                    t          |                               ||          gddddt          j        j                            S )z
    Convert UE8M0 to output_scale for MXFP4 quantization.

    UE8M0 value = 2^(ue8m0 - 127)
    Returns 1 / 2^(ue8m0 - 127) = 2^(127 - ue8m0)
    r#   a  
            {
                .reg .pred p_zero;
                .reg .s32 neg_exp;
                .reg .f32 neg_exp_f, result;

                // Check for zero
                setp.eq.u32 p_zero, $1, 0;

                // Compute 2^(127 - ue8m0) = 1 / 2^(ue8m0 - 127)
                sub.s32 neg_exp, 127, $1;
                cvt.rn.f32.s32 neg_exp_f, neg_exp;
                ex2.approx.f32 result, neg_exp_f;
                selp.f32 $0, 0f00000000, result, p_zero;
            }
            rx   Fr*   ry   )r   r$   r%   s      r    ue8m0_to_output_scaler   :  sl     EGGI''CB'778  " .-	
 	
 	
  r"   rM   rN   rO   rP   v4v5v6v7c                   t          t          j        t          j                    t          |                               ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          t          |                              ||	          gddddt          j        j                            S )zMConvert eight float32 values to eight E2M1 (4-bit) values packed into uint32.r#   a  
            {
                .reg .b8 byte0, byte1, byte2, byte3;
                cvt.rn.satfinite.e2m1x2.f32 byte0, $2, $1;
                cvt.rn.satfinite.e2m1x2.f32 byte1, $4, $3;
                cvt.rn.satfinite.e2m1x2.f32 byte2, $6, $5;
                cvt.rn.satfinite.e2m1x2.f32 byte3, $8, $7;
                mov.b32 $0, {byte0, byte1, byte2, byte3};
            }
            z=r,f,f,f,f,f,f,f,fFr*   r   )
rM   rN   rO   rP   r   r   r   r   r$   r%   s
             r    cvt_e2m1x8_f32r   c  s0    EGG$$$44$$$44$$$44$$$44$$$44$$$44$$$44$$$44		 !" .3	
 	
 	
  r"       widthc           	      F   t          j        t          | t          j                            rt          j        | j        | j                  }|                    |            t          j	        t          j
        | j                            D ]}t          ||         ||          ||<   |                                S t          j	        t          t          j        |                              D ]0} || t          j                            | d|z                      } 1| S )z8Reduce across threads in a warp using butterfly shuffle.rG   )rV   )cutlass
const_expr
isinstancer?   	TensorSSAmake_rmem_tensorshapedtypestorerange_constexprsizewarp_reduceloadintmathlog2archshuffle_sync_bfly)r6   opr   resis        r    r   r     s     *S$.99:: 	#CIsy99		#(39)=)=>> 	4 	4A QU33CFFxxzz(TYu-=-=)>)>?? 	K 	KA"S$)55c!q&5IIJJCC
r"   r   reduction_bufferinit_valc                 f   t           j                                        }t           j                                        }t          j        |j        d                   }||z  }||z  }|dk    r| |||f<   t           j                                         |}	||k     r
|||f         }	t          |	|          S )z:Block reduction across multiple warps using shared memory.rG   r   )r?   r   lane_idxwarp_idxr   r   barrierr   )
r6   r   r   r   r   r   warps_per_rowrow_idxcol_idxblock_reduce_vals
             r    block_reducer     s     y!!##Hy!!##HI.4Q788M-'G&G1}}-0')*I-+GX,=>',,,r"   	cluster_nc           	      >   t           j                                        }t           j                                        }t           j                                        }|j        d         }	|j        d         d         }
||
z  }||
z  }|dk    rct           j                                        5  |	|
z  }||z  dz  }t           j                            ||           ddd           n# 1 swxY w Y   ||k     r%t          | t          ||||ff          ||           t           j        
                    |d           |
|z  }t          j        |d          }|}t          j        |          D ]$}||dz  z   }||k     r |||||f                   }%t          ||          S )z6Cluster reduction across multiple CTAs using mbarrier.r   rG      N)r'   )phaser   )r?   r   block_idx_in_clusterr   r   r   	elect_onembarrier_arrive_and_expect_txr;   rC   mbarrier_waitceil_divr   r   r   )r6   r   r   r7   r   r   cta_rank_in_clusterr   r   rows_per_blockr   r   r   	num_warpsexpected_bytes	num_totalnum_iterr   r   idxs                       r    cluster_reducer     s    )88::y!!##Hy!!##H%+A.N$*1-a0M-'G&G 1}}Y  "" 	N 	N&6I&2Q6NI33HnMMM	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N 	N ))Gg?R5S+TUU%-		
 	
 	
 	
 	IHA... 	)I}Y++H$X.. T TR??!r"24DWc\4RSS',,,s   ).C##C'*C'threads_per_rowc                    |                      ||d          }t          j        j        t          j        t          j        j        t          j        j        i|         }t          |d          }	t          |||	          }
t          |dz  d          }t          j        |dk    p|dk              r>t          j        |dk              rt          |
|||          S t          |
|||||          S |
S )z,Row reduction with optional cluster support.r   )r   reduction_profiler   )r   rG   )reducer?   ReductionOpADDoperatoraddMAXr   fmaxminr   maxr   r   r   r   )r<   r   r   r   r7   r   r   	local_valwarp_op
warp_widthwarp_valr   s               r    
row_reducer     s     h!DDI 	hldin 	
G _b))J9gZ@@@H2-q11M-!+<y1}== i1n-- 	'3CXNNN!'#3Xy(   r"   tXcXlimitc           
         t          j        t          j        t          j        | ddg          t          j        | dg          t          j        | dg          ft          j        | dg          ddf          t          j                  }t	          j        |j        d                   D ]P}t	          j        |j        d                   D ].}t          j        | d|fd|f         d         |          ||d|f<   /Q|S )z,Create predicate tensor for bounds checking.r   rG   )moder   stride)	r?   r   make_layoutr   r   Booleanr   r   	elem_less)r   r   tXpXrest_vrest_ks        r    predicate_kr     s     	$aV,,,	$aS)))	$aS)))
 Id!---q!4	
 	
 	
 	
 
D )$*Q-88  -djm<< 	 	F&*na[!V+,Q/' 'DF"##	 Kr"   c                       e Zd ZdZ	 	 d$dej        dedededededz  d	edz  fd
Z	e
dedej        dedefd            Ze
dedefd            Ze
dedefd            Ze
dedededefd            Ze
dededededef
d            ZdefdZej        dej        dej        dej        dej        dej        dedefd             Zej        dej        dej        dej        dej        dej        deded!ej        d"ej        fd#            ZdS )%RMSNormFP4QuantKernela  
    Fused RMSNorm + FP4 Quantization Kernel.

    Key optimizations:
    1. Half2/BFloat2 SIMD for max-abs computation
    2. Branchless scale clamping via fmin_f32
    3. Cluster synchronization for large H dimensions
    4. Direct 128-bit vectorized global loads
    Nr   H
block_sizeoutput_swizzledis_fp16
sm_versionscale_formatc                 Z   || _         || _        || _        || _        || _        ||nt                      | _        ||dk    rdnd| _        n|| _        |dv sJ d|             | j        dv s
J d            |                     ||| j                  | _	        || j	        z  | _
        |                     | j
                  | _        |                     | j
                  | _        | j        | j        z  | _        t!          | j        dz  d          | _        |j        d	z  }t&          d	z  |z  | _        t!          d| j
        | j        z  | j        z   dz
  | j        z            | _        | j        | j        z  | j        z  | _        ||z  | _        |r||z  }	|	d
z   dz  | _        d| _        d S d S )Nr   ue8m0e4m3r   r   z!block_size must be 16 or 32, got )r   r   z&scale_format must be 'e4m3' or 'ue8m0'rG      rH   r      )r   r   r   r   r   r!   r   r   _compute_cluster_nr   	H_per_cta_compute_threads_per_rowr   _compute_num_threadsnum_threadsr   r   r   r   	COPY_BITSvec_sizenum_vec_blockscols_per_tilenum_sf_blocks_per_rownum_k_tilesk_tile_stride)
selfr   r   r   r   r   r   r   
elem_bytesnum_col_vecss
             r    __init__zRMSNormFP4QuantKernel.__init__=  s    
$.(2(>**NDTDT +5+;+;D ,D X%%%'W:'W'W%%% $55554 655
 00E4?KK dn,  $<<T^LL44T^DD".$2FF !5!;Q?? [A%
!Q*4!^t},t/CCaG#$
 

 "]T-@@4CWW &'*_"  	%
?L ,q 0Q6D!$D	% 	%r"   r   c                    |dk     rdS t           j                            t           j                                                  }|j        }|j        dz  }dD ]2}| |z  dk    rt                              | ||          }||k    r|c S 3dS )a'  Compute optimal cluster size based on H and device shared memory.

        Dynamically determines the minimum cluster_n that fits within the
        device's shared memory limit, making it compatible with different
        GPU architectures (e.g., SM100 with 228KB vs SM120 with 128KB).
        Z   rG   r   )rG   r   r   r   r   r   r   )r   r   r   r   shared_memory_per_block_optinr   r   _estimate_smem_bytes)r   r   r   r   max_smem_bytes	elem_sizer   smem_neededs           r    r   z(RMSNormFP4QuantKernel._compute_cluster_ny  s     ??1
001J1J1L1LMM<K1$	) 	! 	!I9}!!/DD9i K n,,     - rr"   r   c                 V    | dk    rdS | dk    rdS | dk    rdS | dk    rdS | dk    rdS d	S )
z3Compute optimal threads per row based on H per CTA.@   r   r   r   i   r   i    @      r   s    r    r   z.RMSNormFP4QuantKernel._compute_threads_per_row  sW     ??1#2$2$2%33r"   c                     | dk    rdndS )z3Compute total threads per block based on H per CTA.r  r   r  r  r  s    r    r  z*RMSNormFP4QuantKernel._compute_num_threads  s      5((ssc1r"   r   r  c                 `   | |z  }t                               |          }t                               |          }||z  }t          |dz  d          }t          dz  |z  }t          d||z  |z   dz
  |z            }	||	z  |z  }
||
z  |z  }|dk    rd|z  ||z  dz  z   S |||z  |z  dz  z   dz   S )zEstimate shared memory bytes needed for given configuration.

        This is used to dynamically determine cluster_n based on device
        shared memory limits.
        r   rG   r   r   r   )r   r   r  r   r  )r   r   r  r   r   r  r   r   r  r  r  
tile_bytess               r    r  z*RMSNormFP4QuantKernel._estimate_smem_bytes  s     N	/HHSS+@@KK$7Or1155>Y.	X%7!;O
 
 !>1OC#m3i?
>>z>N]$BQ$FFF  > JQ NNQRRRr"   r   r   r  r  c                 <    | |f||ff}||z  df|||z  | z  ff}||fS )zBCreate Thread-Value layout for coalesced vectorized memory access.rG   r  )r   r   r  r  r   r   s         r    _make_tv_layoutz%RMSNormFP4QuantKernel._make_tv_layout  sO     n-~&

 &*^h6HI
 f}r"   c                     | j         | j        z  | j        j        dz  z  }| j        dk    r| j         | j        z  dz  }n| j         | j        z  | j        z  dz  }| j        dk    rdnd}||z   |z   S )z$Calculate shared memory requirement.r   rG   r   r   )r   r  r   r   r   r   )r
  r  reduction_bytes
mbar_bytess       r    _smem_size_in_bytesz)RMSNormFP4QuantKernel._smem_size_in_bytes  s     (4+==AQUVAVW
 >Q"1D4FFJOO #d&884>IAM 
 .1,,QQ!
O+j88r"   mXmWmYmSmGlobalScaleMepsc	                    |                      | j        | j        | j        | j                  \  }	}
t          j        |	|
          }| j        | j        f}|                     |||||||||	  	        	                    t          j
        || j                  | j        dg| j        ddgt          j        | j        dk              r
d| j        dgnd|                                 |           dS )a~  Host function to launch the kernel.

        Takes tensors directly via TVM-FFI.
        - mX: Input tensor, shape (M, H), row-major
        - mW: Weight tensor, shape (H,)
        - mY: Output FP4 tensor, shape (M, H // 2), row-major (packed)
        - mS: Scale factor tensor, shape depends on swizzle mode
        - mGlobalScale: Global scale tensor, shape (1,), float32
        r   rG   N)gridblockclustersmemstream)r  r   r   r  r  r?   r   r  kernellaunchr   r   r  r   r   r#  )r
  r$  r%  r&  r'  r(  r)  r*  r0  tv_shape	tv_stride	tv_layouttiler_mns                r    __call__zRMSNormFP4QuantKernel.__call__  s   . #22 M	
 
) $Xi@@@	');< 	BBL!S)XNNUU-4#677K#Q*!$.1"455Q**))++ 	V 	
 	
 	
 	
 	
r"   r5  r6  c
                 
5   t           j                                        \  }
}}t           j                                        \  }}}| j        }| j        }| j        }| j        }| j        }t          j
        |dk              r%t           j                                        d         }nt          j
        d          }|j        d         d         }t          |dz  d          }|	d         }|
|z  }|
|z  }t          t          t                              }t          j                                        }|                    |j        t          j        |	d          d          }t          j
        |dk              r4|                    t          t          j        ||f          d          }d	}nO|                    t          t          j        |||ff          d          }|                    t.          d
          }t          j
        |dk              r|
dk    r t           j                            |d           t           j                                         t           j                                         t           j                                         t          j        |j                  }t          j        ||	||f          }t          j        ||	||f          }t          j        |j        t          j        |	d         fd                    } t          j         |j!        |           }!t          j        |!|	d|f          }t          j"        t           j#        j$        %                                |j        tL                    }"t          j'        |"||	          }#|#(                    |
          }$|$)                    |          }%|$*                    |          }&|$)                    |          }'t          j+        |%          }(tY          |'|          })|'d         }*|*d         |k     }+|+rt          j-        |"|%|&|)           t           j        .                                 t           j        /                    d           t          j0        |&|(           |(1                                2                    t                    },|,|,z  }-tg          |-t           j4        j5        ||||t          d                    }.|.|z  }/t           j6        7                    |/|z   d          }0|d         }1t          j
        |dk              r=t           j                                         t           j                                         nt           j        8                                 ||z  |z   }2|2|k     r||z   dz
  |z  }3ts          |3          D ]}4||4|z  z   }5|5|k     r|5|z  }6t          j
        |dk              rtu          ||2|z  |6z             }7tu          ||2|z  |6z   tw          d          z             }8tu          ||6          }9tu          ||6tw          d          z             }:ty          |7          \  };}<}=}>ty          |8          \  }?}@}A}Bty          |9          \  }C}D}E}Fty          |:          \  }G}H}I}Jt          j
        |          rt{          |;|C          }Kt{          |<|D          }Lt{          |=|E          }Mt{          |>|F          }Nt{          |?|G          }Ot{          |@|H          }Pt{          |A|I          }Qt{          |B|J          }Rt}          |K          }St}          |L          }Tt}          |M          }Ut}          |N          }Vt}          |O          }Wt}          |P          }Xt}          |Q          }Yt}          |R          }Zt          |S|T          }[t          |U|V          }\t          |W|X          }]t          |Y|Z          }^t          |[|\          }_t          |]|^          }`t          |_|`          }at          |a          }b|b|0z  }ct          |K|0          \  }d}et          |L|0          \  }f}gt          |M|0          \  }h}it          |N|0          \  }j}kt          |O|0          \  }l}mt          |P|0          \  }n}ot          |Q|0          \  }p}qt          |R|0          \  }r}snt          |;|C          }Kt          |<|D          }Lt          |=|E          }Mt          |>|F          }Nt          |?|G          }Ot          |@|H          }Pt          |A|I          }Qt          |B|J          }Rt          |K          }St          |L          }Tt          |M          }Ut          |N          }Vt          |O          }Wt          |P          }Xt          |Q          }Yt          |R          }Zt          |S|T          }[t          |U|V          }\t          |W|X          }]t          |Y|Z          }^t          |[|\          }_t          |]|^          }`t          |_|`          }at          |a          }b|b|0z  }ct          |K|0          \  }d}et          |L|0          \  }f}gt          |M|0          \  }h}it          |N|0          \  }j}kt          |O|0          \  }l}mt          |P|0          \  }n}ot          |Q|0          \  }p}qt          |R|0          \  }r}s|1|cz  |z  }tt          |tt          t                              }tt          |t          }ut          |ut          d          z            }vt          |u          |1z  }wt          j
        | jM                  r|5tw          d          z  }x|2tw          d          z  tw          d          z  }y|2tw          d          z  }z|5tw          d          z  }{|2tw          d          z  }|| jN        | jO        z  }}|||}z  |{| jO        z  z   |ztw          d          z  z   |ytw          d          z  z   |xz   }~|v||~<   n|v||2|5f<   |d|wz  }|e|wz  }|f|wz  }|g|wz  }|h|wz  }|i|wz  }|j|wz  }|k|wz  }|l|wz  }|m|wz  }|n|wz  }|o|wz  }|p|wz  }|q|wz  }|r|wz  }|s|wz  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|6dz  }tu          ||2|dz  z  |z             }t          ||           Htu          ||2|z  |6z             }tu          ||2|z  |6z   tw          d          z             }tu          ||6          }tu          ||6tw          d          z             }tu          ||2|z  |6z   tw          d          z             }tu          ||2|z  |6z   tw          d          z             }tu          ||6tw          d          z             }tu          ||6tw          d          z             }ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}ty          |          \  }}}}t          j
        |          r0t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t{          ||          }t}          |          }t}          |          }t}          |          }t}          |          }t}          |          }t}          |          }t}          |¦          }t}          |æ          }t}          |Ħ          }t}          |Ŧ          }t}          |Ʀ          }t}          |Ǧ          }t}          |Ȧ          }t}          |ɦ          }t}          |ʦ          }t}          |˦          }t          ||ͦ          }t          ||Ϧ          }t          ||Ѧ          }t          ||Ӧ          }t          ||ݦ          }t          ||ߦ          }t          ||          }t          ||զ          }t          ||צ          }t          ||٦          }t          ||ۦ          }t          ||          }t          ||          }t          ||          }t          ||          }at          |a          }b|b|0z  }ct          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  } }t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}	n.t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          ||          }t          |          }t          |          }t          |          }t          |          }t          |          }t          |          }t          |¦          }t          |æ          }t          |Ħ          }t          |Ŧ          }t          |Ʀ          }t          |Ǧ          }t          |Ȧ          }t          |ɦ          }t          |ʦ          }t          |˦          }t          ||ͦ          }t          ||Ϧ          }t          ||Ѧ          }t          ||Ӧ          }t          ||ݦ          }t          ||ߦ          }t          ||          }t          ||զ          }t          ||צ          }t          ||٦          }t          ||ۦ          }t          ||          }t          ||          }t          ||          }t          ||          }at          |a          }b|b|0z  }ct          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  } }t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}t          ||0          \  }}	t          j
        | jS        dk              rG|c|z  }tt          |t          }
t          |
t          d          z            }t          |
          }wnk|1|cz  |z  }tt          |tt          t                              }tt          |t          }ut          |ut          d          z            }t          |u          |1z  }wt          j
        | jM                  r|5tw          d          z  }x|2tw          d          z  tw          d          z  }y|2tw          d          z  }z|5tw          d          z  }{|2tw          d          z  }|| jN        | jO        z  }}|||}z  |{| jO        z  z   |ztw          d          z  z   |ytw          d          z  z   |xz   }~|||~<   n|||2|5f<   ||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|6dz  }tu          ||2|dz  z  |z             }t          ||           ||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }| |wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }||wz  }|	|wz  }t          ||||||||          }t          ||||||||          }t          |          t          d          z  t          |          z  }|6dz   dz  }tu          ||2|dz  z  |z             }t          ||           d	S d	S )a  Device kernel with cluster synchronization for large H.

        mGlobalScale contains the global scale value. The kernel reads it and
        computes 1/global_scale, which is multiplied with rstd to apply:
        y = x * rstd * w / global_scale = rmsnorm(x, w) / global_scale
        rG   r   r   rG   r   )orderr   )byte_alignmentr   N)	num_elems)r   r   )num_bits_per_copy)r   ))r   r   r   r   )predg        T)fastmathr      r   r      r   )Vr?   r   
thread_idx	block_idxr   r   r  r   r   r   r   r   r   rc   r   FLOAT4_E2M1_MAXutilsSmemAllocatorallocate_tensorelement_typemake_ordered_layoutr   allocate_arrayr   mbarrier_initmbarrier_init_fencecluster_arrive_relaxedcluster_waitmake_identity_tensor
local_tileprependrA   make_tensorr>   make_copy_atomnvgpucpasync	CopyG2SOpr  make_tiled_copy	get_slicepartition_Spartition_Dmake_fragment_liker   copycp_async_commit_groupcp_async_wait_groupautovec_copyr   tor   r   r   r   rsqrtr   ranger]   r   rQ   rj   rs   rv   rz   r   rm   r|   r~   r   r   rg   FLOAT8_E4M3_MAXr   r
   r   r   r   r  r	  r   r	   rT   r   r   r   (  r
  r$  r%  r&  r'  r(  r)  r*  r5  r6  tidx_bidxr   r   r  r   r   	cluster_yr   r   r   lane_in_rowrow_in_blockfp4_max_rcpr/  sXr   r7   idXgXcXmW_expanded_layoutmW_2dcopy_atom_load_asynctiled_copy_load
thr_copy_XtXgXtXsXr   tXrXr   	row_coordrow_in_boundsr<   x_sqsum_sqmean_sqrstdglobal_scale_valactual_row_idxnum_sf_per_threadsf_itersf_idxblock_startx_ptr0x_ptr1w_ptr0w_ptr1x0x1x2x3x4x5x6x7w0w1w2w3w4w5w6w7xw0xw1xw2xw3xw4xw5xw6xw7abs0_h2abs1_h2abs2_h2abs3_h2abs4_h2abs5_h2abs6_h2abs7_h2max01_h2max23_h2max45_h2max67_h2
max0123_h2
max4567_h2	max_xw_h2max_xwmax_absy0y1y2y3y4y5y6y7y8y9y10y11y12y13y14y15scale_floatscale_fp8_u32	scale_fp8	inv_scaleinner_k_idxinner_m_idxouter_m_idx
k_tile_idx
m_tile_idxm_tile_strideswizzled_offsetq0q1q2q3q4q5q6q7q8q9q10q11q12q13q14q15	packed_lo	packed_hipacked64
out_offsetout_ptr	x_ptr0_c0	x_ptr1_c0	w_ptr0_c0	w_ptr1_c0	x_ptr0_c1	x_ptr1_c1	w_ptr0_c1	w_ptr1_c1x0_c0x1_c0x2_c0x3_c0x4_c0x5_c0x6_c0x7_c0w0_c0w1_c0w2_c0w3_c0w4_c0w5_c0w6_c0w7_c0x0_c1x1_c1x2_c1x3_c1x4_c1x5_c1x6_c1x7_c1w0_c1w1_c1w2_c1w3_c1w4_c1w5_c1w6_c1w7_c1xw0_c0xw1_c0xw2_c0xw3_c0xw4_c0xw5_c0xw6_c0xw7_c0xw0_c1xw1_c1xw2_c1xw3_c1xw4_c1xw5_c1xw6_c1xw7_c1
abs0_c0_h2
abs1_c0_h2
abs2_c0_h2
abs3_c0_h2
abs4_c0_h2
abs5_c0_h2
abs6_c0_h2
abs7_c0_h2
abs0_c1_h2
abs1_c1_h2
abs2_c1_h2
abs3_c1_h2
abs4_c1_h2
abs5_c1_h2
abs6_c1_h2
abs7_c1_h2max01_c0_h2max23_c0_h2max45_c0_h2max67_c0_h2max0123_c0_h2max4567_c0_h2	max_c0_h2max01_c1_h2max23_c1_h2max45_c1_h2max67_c1_h2max0123_c1_h2max4567_c1_h2	max_c1_h2y0_c0y1_c0y2_c0y3_c0y4_c0y5_c0y6_c0y7_c0y8_c0y9_c0y10_c0y11_c0y12_c0y13_c0y14_c0y15_c0y0_c1y1_c1y2_c1y3_c1y4_c1y5_c1y6_c1y7_c1y8_c1y9_c1y10_c1y11_c1y12_c1y13_c1y14_c1y15_c1scale_ue8m0scale_u8s                                                                                                                                                                                                                                                                              r    r1  zRMSNormFP4QuantKernel.kernel  s   & Y))++
aY((**
aF_
 $ :,N	 i!m,, 	.	++--a0II*1--I#/!,Q/Or1155!! _,. %W_%=%=>>
 }**,, !!O$XV<<< " 
 
 i1n-- 	?#33 .-!@AA   4    
 HH#33 .=)2L!MNN   4    
 **5A*>>H
 i!m,, 	%qyy	''!444I))+++I,,...I""$$$
 '11 _RD)+<==_S(T9,=>> "\It'!tDDD
 
  .@AAO8a^
 
  $2J((**O' 
  
  
 . )X
 
 %..t44
 %%b))%%b))%%b)) &t,,
 4q)))&	!!q(
  	CI*D$TBBBB	'')))	%%a(((
 	$%%%IIKKNN7## 1u CLL
 
 1*yw}t<< (? i!m,, 	 I,,...I""$$$$I .=A &7!; !! !!233 E9 E9$w'@@111"(:"5K )**:;; {9!1"nq6H;6V!W!W!1 2[ @588 K" " "2"k!B!B!1"kE!HH6L!M!M *:&)A)ABB)9&)A)ABB *:&)A)ABB)9&)A)ABB #-g66 OK"+B"3"3C"+B"3"3C"+B"3"3C"+B"3"3C"+B"3"3C"+B"3"3C"+B"3"3C"+B"3"3C ',CjjG&+CjjG&+CjjG&+CjjG&+CjjG&+CjjG&+CjjG&+CjjG',Wg'>'>H',Wg'>'>H',Wg'>'>H',Wg'>'>H).x)B)BJ).x)B)BJ(-j*(E(EI%0%;%;F&,tmG &<C%F%FFB%;C%F%FFB%;C%F%FFB%;C%F%FFB%;C%F%FFB'=c4'H'HHC'=c4'H'HHC'=c4'H'HHC"-b""5"5C"-b""5"5C"-b""5"5C"-b""5"5C"-b""5"5C"-b""5"5C"-b""5"5C"-b""5"5C '4C&8&8G&3C&8&8G&3C&8&8G&3C&8&8G&3C&8&8G&3C&8&8G&3C&8&8G&3C&8&8G'4Wg'F'FH'4Wg'F'FH'4Wg'F'FH'4Wg'F'FH)6x)J)JJ)6x)J)JJ(5j*(M(MI%8%C%CF&,tmG &>c4%H%HFB%=c4%H%HFB%=c4%H%HFB%=c4%H%HFB%=c4%H%HFB'?T'J'JHC'?T'J'JHC'?T'J'JHC '7&@;&N&.{GO<T<T&U&U(7(D(D$)-&,,*F$G$G	
 4MBBEUU " #-d.BCC C*0588*;K+9E#JJ+F5QS99*TK*8599*DK)/588);J)75::)EJ,0,<t?Q,QM *] :",t/A"A!B"-b		"9!: #.a"8!9 #.	!. , 3<B//9BB~v56
  )^)^)^)^)^)^)^)^)^)^!Io!Io!Io!Io!Io!Io$22r2r2r2r$R$R	$22r3S#sTW$X$X	$*9$5$5$CviGXGX#X%0A%5
"2!q& 9J F# # &gx8888 %5 2[ @% %	 %5 2[ @588 K% %	 %5R$E$E	$4RuQxx9O$P$P	$4 2[ @599 L% %	 %5 2[ @599 L% %	 %5RuRyy9P$Q$Q	$4RuRyy9P$Q$Q	5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU5Ei5P5P2ueU"-g66 ST%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F%.ue%<%<F).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ).vJ*/
J*G*GK*/
J*G*GK*/
J*G*GK*/
J*G*GK,1+{,K,KM,1+{,K,KM(-m](K(KI*/
J*G*GK*/
J*G*GK*/
J*G*GK*/
J*G*GK,1+{,K,KM,1+{,K,KM(-m](K(KI(-i(C(CI%0%;%;F&,tmG+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5-CFD-Q-QNFF-CFD-Q-QNFF-CFD-Q-QNFF+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5+A&$+O+OLE5-CFD-Q-QNFF-CFD-Q-QNFF-CFD-Q-QNFFF%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F%0%>%>F)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J)6v)>)>J*7
J*O*OK*7
J*O*OK*7
J*O*OK*7
J*O*OK,9+{,S,SM,9+{,S,SM(5m](S(SI*7
J*O*OK*7
J*O*OK*7
J*O*OK*7
J*O*OK,9+{,S,SM,9+{,S,SM(5m](S(SI(5i(K(KI%8%C%CF&,tmG+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5-Efd-S-SNFF-Efd-S-SNFF-Efd-S-SNFF+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5+CFD+Q+QLE5-Efd-S-SNFF-Efd-S-SNFF-Efd-S-SNFF
 #-d.?7.JKK *1K*?K*:;*G*GK',[6$<<-G'H'HH(=k(J(JII +;W*D{*RK*2 +W_-E-E+ +K -<K,H,HM',]VD\\-I'J'JH !8 F F"2!3 &
 #-d.BCC B*0588*;K+9E#JJ+F5QS99*TK*8599*DK)/588);J)75::)EJ,0,<t?Q,QM *] :",t/A"A!B"-b		"9!: #.a"8!9 #.	!. , 3;B//9AB~v56 #Y."Y."Y."Y."Y."Y."Y."Y."Y."Y.$y0$y0$y0$y0$y0$y0$22r2r2r2r$R$R	$22r3S#sTW$X$X	$*9$5$5$CviGXGX#X%0A%5
"2!q& 9J F# # &gx888 #Y."Y."Y."Y."Y."Y."Y."Y."Y."Y.$y0$y0$y0$y0$y0$y0$22r2r2r2r$R$R	$22r3S#sTW$X$X	$*9$5$5$CviGXGX#X&1B&61%<
"2!q& 9J F# # &gx888W E9 E9r"   )NN)__name__
__module____qualname____doc__r   Numericr   boolstrr  staticmethodr   r   r  r  tupler  r#  r?   jitTensorr   r   r7  r1  LayoutShaper  r"   r    r   r   2  s        " "&#':% :%:% :% 	:%
 :% :% $J:% Dj:% :% :% :%x c '/ s s    \2 C C    \ 2 2 2 2 2 \2 S S S S S S S \S6   	
 
   \"9S 9 9 9 9$ 
X(
K(
 K(
 K	(

 K(
 k(
 (
 (
 (
 (
 X(
T 
[}	9K}	9 K}	9 K	}	9
 K}	9 k}	9 }	9 }	9 ;}	9 *}	9 }	9 }	9 [}	9 }	9 }	9r"   r   hidden_sizer   r   r   r   is_sf_swizzled_layoutc                    |rt           j        nt           j        }t          || ||||          }t	          j                    }t          j                            ||| fdd          }	t          j                            || fd          }
t          j                            t           j        || dz  fdd          }rAt	          j                    }t          j                            t           j        |fd          }n2t          j                            t           j        || |z  fdd          }t          j        	                    d          }t          j                            t           j
        d	d
          }t	          j        ||	|
|||t          d          t          d          |d
  
        dt          j        dt          j        dt          j        dt          j        dt          j        dt          dt           ddffd}|S )z
    Get a compiled kernel closure that takes torch.Tensor directly.

    Uses TVM-FFI for efficient tensor passing without manual pointer construction.
    )r   r   r   r   r   r   r   r9  r   )stride_orderassumed_align)ri  r   T)use_tvm_ffi_env_stream)rG   r   rG   ư>z--enable-tvm-ffi)optionsr<   wysglobal_scaler)  r*  r   Nc                     
r|                                 n|                                }|                    t          j                  } 	| ||||t          |          t          |                     dS )z;Runtime API that passes torch tensors directly via TVM-FFI.N)flatten
contiguousviewr   uint8r   r   )r<   rm  rn  ro  rp  r)  r*  s_tensory_uint8compiled_kernelrf  s            r    
tensor_apiz(_get_compiled_kernel.<locals>.tensor_api  su     #8K199;;;Q\\^^&&%%!HHCLL	
 	
 	
 	
 	
r"   )r   Float16BFloat16r   r?   sym_intruntimemake_fake_compact_tensorr
   make_fake_streamr   compiler   r   rb  r   float)re  r   r   r   r   rf  cutlass_dtype
kernel_objsym_mx_fakew_fakey_fakesym_swizzled_sizes_fakestream_fakeglobal_scale_fakery  rx  s        `           @r    _get_compiled_kernelr    sW    (/DGOOG4DM '
-!  J LNNE \22{+&PS 3  F \22~S 3  F \22{a/0vUX 3  F
  
 !LNN66M-/s 7 
 

 66MK:-.	 7 
 
 ,//t/LLK ==Q >  
 la"  O
<
<
 <
 <	

 l
 
 
 

 
 
 
 
 
 
2 r"   rk  Finputweighty_fp4block_scalerp  r*  c	           	      H   |                                  dk    }	|	r7| j        \  }
}}|                     |
|z  |                                          }n| }|j        \  }}| j        }||z  dk    s
J d            |dk    s
J d            |dv s
J d            |t
          j        k    }|r|n	|dk    rd	nd
}t          | j                  }|Z|	r-t          j	        |
||dz  ft
          j
        | j                  }n+t          j	        ||dz  ft
          j
        | j                  }||d	k    rt
          j        nt
          j        }||z  }|r8|dz   dz  }|dz   dz  }d}||z  |z  }t          j	        |f|| j                  }n@|	r t          j	        |
||f|| j                  }nt          j	        ||f|| j                  }|	r7|                    |
|z  d          }|s|                    |
|z  d          n|}n|}|}|&t          j        dt
          j        | j                  }t          ||||||          } ||                                |                                ||                    t
          j                  |                                ||           ||fS )a  
    Fused RMS normalization with FP4 quantization using CuTe-DSL.

    Computes: ``y = RMSNorm(input) * weight``, optionally applies global scaling
    (``y = y / global_scale``), then quantizes ``y`` to FP4.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor, shape ``(batch_size, hidden_size)`` or ``(batch_size, seq_len, hidden_size)``.
        Must be ``torch.float16`` or ``torch.bfloat16``.
    weight : torch.Tensor
        Weight tensor for RMSNorm, shape ``(hidden_size,)``.
        Must have the same dtype as input.
    y_fp4 : torch.Tensor, optional
        Output tensor for quantized values in FP4_E2M1 format with dtype
        ``torch.float4_e2m1fn_x2``.
        Shape must be ``(batch_size, hidden_size // 2)`` or matching 3D input.
        If ``None``, will be allocated automatically.
    block_scale : torch.Tensor, optional
        Output tensor for per-block scale factors.

        - If ``is_sf_swizzled_layout=False`` (default): row-major layout with shape
          ``(batch_size, hidden_size // block_size)`` or matching 3D input.
        - If ``is_sf_swizzled_layout=True``: swizzled layout for efficient tensor core
          access, with shape ``(batch_size * hidden_size // block_size,)`` flattened.
          The swizzle pattern uses 128x4 tiles where scales are arranged as:
          ``[m_tile][k_tile][outer_m (32)][inner_m (4)][inner_k (4)]``.

        Dtype should be ``torch.float8_e4m3fn`` for E4M3 format or ``torch.uint8``
        for UE8M0 format. If ``None``, will be allocated automatically.
    global_scale : torch.Tensor, optional
        Global scale factor tensor of shape ``(1,)`` with dtype ``torch.float32``.
        If provided, the RMSNorm output is divided by this value before quantization:
        ``y = rmsnorm(x, w) / global_scale``. This is used for NVFP4 format where
        a pre-computed global scale lifts per-block scales into optimal dynamic range.
        If ``None``, no global scaling is applied (equivalent to global_scale=1.0).
    eps : float
        Epsilon for numerical stability in RMSNorm. Default is ``1e-6``.
    block_size : int
        Number of elements per quantization block. Default is ``16``.

        - ``16``: NVFP4 format with E4M3 scale factors
        - ``32``: MXFP4 format with UE8M0 scale factors
    scale_format : str, optional
        Scale factor format: ``"e4m3"`` or ``"ue8m0"``.
        If ``None``, auto-selects based on ``block_size``:
        ``"e4m3"`` for block_size=16, ``"ue8m0"`` for block_size=32.
    is_sf_swizzled_layout : bool
        If ``True``, output scale factors in swizzled layout optimized for
        tensor core GEMM operations. The swizzle uses 128x4 tiles with the pattern:
        ``[m_tile_idx * k_tiles * 512 + k_tile_idx * 512 + outer_m * 16 + inner_m * 4 + inner_k]``
        where ``outer_m = row % 32``, ``inner_m = (row % 128) // 32``, etc.
        Default is ``False`` (row-major layout).

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        A tuple of ``(y_fp4, block_scale)``:

        - ``y_fp4``: Quantized FP4 values packed as uint8.
        - ``block_scale``: Per-block scale factors.

    Notes
    -----
    - Requires SM100+ (Blackwell) for FP4 quantization PTX intrinsics.
    - For block_size=16 (NVFP4): uses E4M3 scale factors (max value 448.0).
    - For block_size=32 (MXFP4): uses UE8M0 scale factors (power-of-2 scales).
    - FP4 E2M1 format has a max representable value of 6.0.
    rH   r   z+hidden_size must be divisible by block_sizer  zhidden_size must be >= 64r   zblock_size must be 16 or 32r   r   r   Nr   )r   r      r   r   r   rG   )dimr   rt  rs  r   r   float16r!   r   emptyfloat4_e2m1fn_x2ru  float8_e4m3fnonesfloat32r  )r  r  r  r  rp  r*  r   r   rf  is_3dBSr   input_2d
batch_sizere  r   r   actual_scale_formatr   scale_dtyper  num_m_tilesr  r	  swizzled_sizey_fp4_2dblock_scale_2dry  s                                r    rmsnorm_fp4quantr    s_   f IIKK1E +1a::a!eQ''2244&nJKE#q(((*W((("9!!!#@!!! u}$G$SjB6F6F77F   --J } 	KA{a'(,|  EE K[A-.,|  E  /'99EKKu?R 	 !,z 9  	%+3K014:KM'+5EM+ EL  KK  #k01% <   $k!67% <    %::a!eR((/DUKQUB'''+ 	 $ z!5=NNN & J JEK((!!   +r"   )r   r!   r  )N)r   )NNNrk  r   NF)Jr[  	functoolsr   r   typingr   r   r   cutlass.cuter?   r   r   r   r   r   r	   r
   cutlass.cutlass_dslr   r   cutlass._mlir.dialectsr   api_loggingr   rD  rc  SF_VEC_SIZEr  	lru_cacher   r   r^  r!   Pointerr5   r;   rb  rC   rQ   rT   r]   rc   rg   rj   rm   rs   rv   rz   r|   r~   r   r   r   r   r   r   r   r   ra  	Constexprr   r   r   r   r   r   r   r   cacher]  r  r  r  __all__r  r"   r    <module>r     s
   2       " " " " " " " "         @ @ @ @ @ @ @ @ @ @ @ @ @ @ @ @ . . . . . . . . ' ' ' ' ' ' ( ( ( ( ( ( 	 R   * *3-3d: *c * * * ! *. DHT  l6;
   $  	  	l l $	 
   6 /3 F F FDK F F F F F  T: : ::
6666)*: : : :0 9=$   E &      @D   T[ % RW     &*t   g      ,0T    G g    ( +/D    F V     -1d   6 f v    (  T   V f     '+   V       "&4   6 G    0 $(T   V f     /3   V       *.4   6 G    @ '+$ $ $$$
7G$ $ $ $@ (, $  $  $	 $ $
7G $  $  $  $P '+   w     . 48T      V  g        P .2t + + +g + + + + +\ 48T      V  g        P  	* * *** 	* 		*
 	* 	* 	* 	* * * * *d   1# 6    
 -	-- k- 	-
 - - - 
-. 0-	0-0- k0- l	0-
  %0- 0- 0- 0- 0- 
0-f ~ &s+ k	  %    
J dk # $+    
4_9 _9 _9 _9 _9 _9 _9 _9N jjj j 	j
 j  j j j j jZ  "&'+(,#"'u u<uLu <$u $	u
 ,%u 
u u *u  u 5<%&u u u up  r"   