
    )`iN                     p   d Z ddlZddlmZ ddlmZ ddlmZmZ ddl	Z	ddl
mZ ddlmZ d	d
lmZ d	dlmZmZ ej        d             Zde	j        dee	j                 dedededededee	j        e	j        e	j        e	j        e	j        e	j        f         fdZde	j        de	j        de	j        de	j        de	j        de	j        dedededededdfdZde	j        de	j        de	j        d e	j        de	j        d!e	j        d"e	j        dededdfd#Zd$eddfd%Zdedefd&Zdedefd'Zd(e	j        d)ee	j                 d*ee	j                 d+e	j        dedededed,ededee	j        e	j        e	j        e	j        e	j        e	j        e	j        e	j        f         fd-Ze G d. d/                      Z G d0 d1          ZdS )2a3  
Copyright (c) 2025 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
    N)	dataclass)SimpleNamespace)OptionalTuple   )gen_comm_alltoall_module)register_custom_op   )Mapping)MnnvlMemoryMnnvlConfigc                     t                                                      t          dg           dt          j        dt
          t          j                 dt          dt          dt          dt          d	t          d
t          t          j        t          j        t          j        t          j        t          j        t          j        f         ffd            } t          dddg          dt          j        dt          j        dt          j        dt          j        dt          j        dt          j        dt          dt          dt          dt          d	t          d
d ffd            }t          ddg          dt          j        dt          j        dt          j        dt          j        dt          j        dt          j        dt          j        dt          d	t          d
d ffd            }t          dg           dt          d
d ffd            }t          dg           d	t          d
t          ffd             }t          d!g           d	t          d
t          ffd"            }t          d#g           d$t          j        d%t
          t          j                 d&t
          t          j                 d't          j        dt          dt          d	t          dt          d(t          dt          d
t          t          j        t          j        t          j        t          j        t          j        t          j        t          j        t          j        f         ffd)            }t          | ||||||*          S )+Nz$flashinfer::moe_comm_prepare_indices)mutates_argsgathered_target_rank_idsreal_rank_token_count_cum_summax_token_count_per_rankexpert_counttop_kep_rankep_sizereturnc                 *   | j         }t          ||          }t          j        ||z  |t          j                  }	t          j        |f|t          j                  }
t          j        ||z  |t          j                  }t          j        ||t          j                  }t          j        ||z  |t          j                  }t          j        ||z  |t          j                  }                    | ||	|
|||||||||           |	|
||||fS )N)devicedtype)r   maxtorchemptyintmoe_comm_prepare_indices)r   r   r   r   r   r   r   r   max_send_ranks_per_tokenlocal_gather_indicessend_rank_count_cum_sumsend_rank_local_indicesrecv_rank_count_cum_sumrecv_rank_local_indicesbackward_recv_rank_local_indicemodules                  s/home/jaya/work/projects/VOICE-AGENT/VIET/agent-env/lib/python3.11/site-packages/flashinfer/comm/trtllm_alltoall.pyr   z:get_comm_alltoall_module.<locals>.moe_comm_prepare_indices"   sY   ( *0#&ug#6#6 ${%/uy 
  
  
 #(+JvUY#
 #
 #
 #(+%(@@)#
 #
 #

 #(+wei"X"X"X"'+%/uy#
 #
 #
 +0+%(@@)+
 +
 +
'
 	''$) ####+$	
 	
 	
  !####+
 	
    zflashinfer::moe_local_gatherlocal_expert_idslocal_scalesrecv_rank_cum_sumr!   gathered_expert_idsgathered_scalesc                 F                         | |||||||||	|
           d S N)moe_local_gather)r,   r!   r-   r.   r*   r+   r   r   r   r   r   r'   s              r(   r1   z2get_comm_alltoall_module.<locals>.moe_local_gatherd   sH    " 	 $	
 	
 	
 	
 	
r)   zflashinfer::moe_commoutputinputsend_rank_cum_sumsend_indicesrecv_indicesall_workspacesc	                 B    	                     | ||||||||	  	         d S r0   )moe_comm)
r3   r4   r5   r2   r,   r6   r7   r   r   r'   s
            r(   r9   z*get_comm_alltoall_module.<locals>.moe_comm   s@     	
	
 
	
 
	
 
	
 
	
r)   z'flashinfer::set_moe_max_usable_sm_countmax_sm_countc                 2                         |            d S r0   )set_moe_max_usable_sm_count)r:   r'   s    r(   r<   z=get_comm_alltoall_module.<locals>.set_moe_max_usable_sm_count   s     	**<88888r)   z/flashinfer::get_moe_commworkspace_size_per_rankc                 .                         |           S r0   )#get_moe_commworkspace_size_per_rankr   r'   s    r(   r>   zEget_comm_alltoall_module.<locals>.get_moe_commworkspace_size_per_rank   s     99'BBBr)   z3flashinfer::get_moe_prepare_workspace_size_per_rankc                 .                         |           S r0   )'get_moe_prepare_workspace_size_per_rankr?   s    r(   rA   zIget_comm_alltoall_module.<locals>.get_moe_prepare_workspace_size_per_rank   s     ==gFFFr)   zflashinfer::moe_prepareexperts_idsscalesexperts_statics	workspace
slot_countc
                    t           j        | j        d}
t          j        ||z  |	ffi |
}t          j        |ffi |
}t          j        |ffi |
}t          j        ||z  ffi |
}t          j        ||z  ffi |
}t	          ||	          }t          j        ||z  ffi |
}t          j        ||z  ffi |
}t          j        ||z  ffi |
}t          j        ||z  ffi |
}|-t          j        ||z  |	ft           j        |
d                   }nd }|t          j        ||ffi |
}nd }                    | ||||||||||||||||||||	           ||||||||fS )Nr   r   r   )r   int32r   r   r   float32moe_prepare)rB   rC   rD   rE   r   r   r   r   rF   r   attrsprepared_local_expert_idsr"   r$   gather_recv_rank_indicesrecv_rank_indicesr    !gather_backward_recv_rank_indicesbackward_recv_rank_indicesgather_send_rank_indicessend_rank_indicesprepared_local_scalesgathered_expert_staticsr'   s                          r(   rK   z-get_comm_alltoall_module.<locals>.moe_prepare   ss   2  +1CDD$)K%/7%
 %
;@%
 %
! #(+wj"B"BE"B"B"'+wj"B"BE"B"B#(;%/1$
 $
5:$
 $
  "K)AG)K(MWWQVWW#&w#6#6 ,1K%(@@B-
 -
FK-
 -
) &+[%(@@B&
 &
FK&
 &
" $);%(@@B$
 $
FK$
 $
  "K%(@@B
 
FK
 
 $)K)G3U;mX% % %!! %)!&&+k7L2I&S&SU&S&S##&*#%##$-&$!#$+	
 	
 	
0 &!#$#$-#	
 		
r)   )r   r1   r9   r<   r>   rA   rK   )	r   build_and_loadr	   r   Tensorr   r   r   r   )r   r1   r9   r<   r>   rA   rK   r'   s          @r(   get_comm_alltoall_modulerX      s   %''6688F.  <
"',<
'/'=<
 #&<
 	<

 <
 <
 <
 
	
<
 <
 <
 <
 <
	 <
| &(.9  
 <
#l
 #\
 	

  ,
 l
 #&
 
 
 
 
 

 
 
 
 
	 
6 Z  
|
 <
 l
 	

 !<
 l
 
 
 
 

 
 
 
 
	 
. 1  99	9 9 9 9 9	 9
 9  CC	C C C C C	 C
 =  GG	G G G G G	 G
 !  ]
\]
&]
 "%,/]
 <	]

 #&]
 ]
 ]
 ]
 ]
 ]
 
		
]
 ]
 ]
 ]
 ]
	 ]
~ !9)$?,O0W   r)   r   r   r   r   r   r   r   r   c           	      P    t                                          | ||||||          S r0   )rX   r   )r   r   r   r   r   r   r   s          r(   r   r   '  s7     $%%>> %   r)   r,   r!   r-   r.   r*   r+   c                 \    t                                          | |||||||||	|
           d S r0   )rX   r1   )r,   r!   r-   r.   r*   r+   r   r   r   r   r   s              r(   r1   r1   =  sM     //     r)   r3   r4   r5   r2   r6   r7   c	                 X    t                                          | ||||||||	  	         d S r0   )rX   r9   )	r3   r4   r5   r2   r,   r6   r7   r   r   s	            r(   r9   r9   Y  sG     ''
 
 
 
 
r)   r:   c                 H    t                                          |            d S r0   )rX   r<   )r:   s    r(   r<   r<   q  s$     ::<HHHHHr)   c                 D    t                                          |           S r0   )rX   r>   r   s    r(   r>   r>   w  s     $%%II'RRRr)   c                 D    t                                          |           S r0   )rX   rA   r^   s    r(   rA   rA   }  s     $%%MMgVVVr)   rB   rC   rD   rE   rF   c
                 V    t                                          | |||||||||	
  
        S r0   )rX   rK   )
rB   rC   rD   rE   r   r   r   r   rF   r   s
             r(   rK   rK     s@    * $%%11   r)   c                       e Zd ZU ej        ed<   ej        ed<   ej        ed<   ej        ed<   ej        ed<   ej        ed<   eed<   dS )	MoEAlltoallInfor!   send_rank_count_cumsumr#   recv_rank_count_cumsumr%    backward_recv_rank_local_indiceslocal_token_allocation_countN)__name__
__module____qualname__r   rW   __annotations__r    r)   r(   rb   rb     sr         ,&&&!L((("\)))!L((("\)))&+l222"%%%%%%r)   rb   c                   V   e Zd ZU dZeed<   dZeed<   dZej	        ed<   dZ
ej	        ed<   dZeed<   ed"dedee         fd	            Ze	 d"dedee         fd
            Zedej	        dedefd            Zedej	        dej	        deej	                 dej	        dedededededefd            Zedej	        dej	        dej	        dej	        dededededefd            Zedej	        dedej	        dedef
d            Zedej	        dedej	        dededed efd!            ZdS )#MnnvlMoeNmoe_workspacemoe_prepare_workspacemoe_workspace_tensormoe_prepare_workspace_tensormoe_mappingmappingconfigc                    t           j        &| t           j        k    s
J d            t           j        S | t           _        t	          | j                  }|rt          j        | |           t          | |          t           _        t           j                            t          j
                  t           _        t           j        S Nz"only one moe mapping supported now)rm   rn   rr   rp   r>   tp_sizer   set_comm_from_configas_torch_strided_tensorr   uint64rs   rt   workspace_size_per_ranks      r(   get_moe_workspaceszMnnvlMoe.get_moe_workspaces  s    !-h22224X22200&"Ego"V"V 	>,Wf===!,W6M!N!N(0(>(V(VL)
 )
% ,,r)   c                 n   t           j        &| t           j        k    s
J d            t           j        S t          | j                  }|rt          j        | |           t          | |          t           _        t           j                            t          j
                  t           _        t           j        S rv   )rm   rq   rr   rA   rw   r   rx   ro   ry   r   rz   r{   s      r(   get_moe_prepare_workspacez"MnnvlMoe.get_moe_prepare_workspace  s     0<h22224X22288"IO#
 #
  	>,Wf===)4W>U)V)V&*BB5<PP 	- 44r)   token_selected_expertsr   r   c                 @    ||z  dk    s
J d            ||z  }| |z  }|S )Nr   z+expert_count should be divisible by ep_sizerk   )r   r   r   expert_per_ranktoken_target_rank_idss        r(   compute_target_rank_idzMnnvlMoe.compute_target_rank_id  sB     g%***9 +** ''1 6/ I$$r)   
expert_idsrC   expert_staticsrE   r   r   rF   r   c
                     t          | |||||||||	
  
        \  }
}}}}}}}||z  }d }t          |||||||          }||
||fS r0   )rK   rb   )r   rC   r   rE   r   r   r   r   rF   r   prepared_local_expertsrT   local_send_rank_count_cumsumlocal_send_rank_indiceslocal_recv_rank_count_cumsumlocal_recv_rank_indices backward_local_recv_rank_indicesrU   rf   r!   alltoall_infos                        r(   -mnnvl_moe_alltoallv_prepare_without_allgatherz6MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather  s    , $
 
		
"!(#(#,# (@''I$#' (#(#,(
 
 "!#	
 	
r)   r   real_rank_token_count_cumsumr-   r.   c	                 z   t          | ||||||          \  }	}
}}}}||z  }t          j        ||t          j        t          j        d                    }t          j        ||t          j        t          j        d                    }t          ||	|||||||||           t          |	|
|||||          }|||fS )NcudarH   )r   r   r   rI   r   rJ   r1   rb   )r   r   r-   r.   r   r   r   r   r   r!   rc   r#   rd   r%   re   rf   r*   r+   r   s                      r(   mnnvl_moe_alltoallv_preparez$MnnvlMoe.mnnvl_moe_alltoallv_prepare  s   & %$($
 
	
 "#"#, (@''I$ ;(+<''	
 
 
 {(-<''	
 
 
 	" $	
 	
 	
 ( "#"#,(
 
 .<<r)   xr   c                    |                                  dk    s
J d            t          j        |j        | j        d         | j        t          j        d                    }t          | |j        |j	        ||j
        |j        |||	  	         |S )Nr   z)only 2D tensor supported, please reshape.r
   r   rH   )dimr   r   rf   shaper   r   r9   rc   r#   rd   r%   )r   r   rE   r   r   output_tensors         r(   mnnvl_moe_alltoallvzMnnvlMoe.mnnvl_moe_alltoallvb  s     uuww!|||H|||6GAJ'<''	
 
 
 	0101
	
 
	
 
	
 r)   token_countc                    |                                  dk    s
J d            t          j        ||z  | j        d         | j        t          j        d                    }t          | |j        |j        ||j	        |j
        |||	  	         t          j        |                    ||| j        d                   dd          S )Nr   z$2D tensor supported, please reshape.r
   r   rH   F)r   keepdim)r   r   zerosr   r   r   r9   rd   r%   rc   re   sumreshape)r   r   rE   r   r   r   r   r   s           r(   mnnvl_moe_alltoallv_combinez$MnnvlMoe.mnnvl_moe_alltoallv_combine~  s     uuww!|||C|||%175<PVCWCW
 
 
 	010:
	
 
	
 
	
 y!!+uagajAAqRW
 
 
 	
r)   r0   )rg   rh   ri   rn   r   rj   ro   rp   r   rW   rq   rr   r   staticmethodr   r   r}   r   r   r   r   r   rb   r   r   rk   r)   r(   rm   rm     s        !%M;%%%)-;---)-%,---15 %,555K- -G -Xk5J - - - \- :>5 55"*;"75 5 5 \5" % %%<?%JM% % % \% 5
L5
5
 !.5
 <	5

 #&5
 5
 5
 5
 5
 5
 5
 5
 \5
n B="',B=&+lB= #\B= 	B=
 #&B= B= B= B= B= B= B= \B=H <& < 	
    \6 
<
&
 <
 	

 
 
 
 
 
 \
 
 
r)   rm   ) __doc__	functoolsdataclassesr   typesr   typingr   r   r   jitr   utilsr	   rs   r   mnnvlr   r   cacherX   rW   r   r   r1   r9   r<   r>   rA   rK   rb   rm   rk   r)   r(   <module>r      s"         ! ! ! ! ! ! ! ! ! ! ! ! " " " " " " " "  * * * * * * & & & & & &       + + + + + + + + E E EP#l#+EL#9 " 	
    	L%,elEL%,V   ,|,  \	
 l , "     
   8<| , L	
 | , L   
   0II	I I I ISSS S S SWWW W W W  U\"  el+  |	 
 "            	L	L	L	L	L	L	L	L	       F & & & & & & & &h
 h
 h
 h
 h
 h
 h
 h
 h
 h
r)   