@@ -25,6 +25,7 @@

 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_ep_group

 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -100,15 +101,34 @@ def __init__(self, **kwargs):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType._910_93)

-        # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
-        self.a3_need_extra_args = \
-            get_ascend_device_type() == AscendDeviceType._910_93
         # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
         # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
         # improve communication performance.
         self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False

+        # We need to calculate global_bs = max_bs_per_rank * ep_world_size so the
+        # dispatch & combine operators can run with different input num_tokens per rank.
+        vllm_config = get_current_vllm_config()
+        scheduler_config = vllm_config.scheduler_config
+        compilation_config = vllm_config.compilation_config
+        speculative_config = vllm_config.speculative_config
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        uniform_decode_query_len = 1 if not speculative_config else \
+            1 + speculative_config.num_speculative_tokens
+        decode_max_num_seqs = getattr(scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        max_num_reqs = max(scheduler_config.max_num_seqs,
+                           decode_max_num_seqs)
+        if compilation_config.cudagraph_capture_sizes:
+            max_num_tokens = compilation_config.max_cudagraph_capture_size
+        else:
+            max_num_tokens = min(
+                max_num_reqs * uniform_decode_query_len, 512)
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        self.global_bs = mc2_tokens_capacity * self.ep_world_size
+
     def get_dispatch_mc2_kwargs(
         self,
         hidden_states: torch.Tensor,
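
For reference, here is a minimal standalone sketch of the global_bs derivation added above. Every input is an illustrative stand-in for a value that __init__ reads from the live vllm config, and compute_global_bs is a hypothetical helper written for this note, not part of the patch:

def compute_global_bs(max_num_seqs,
                      decode_max_num_seqs,
                      num_speculative_tokens,
                      max_cudagraph_capture_size,
                      tp_size,
                      ep_world_size):
    # One token per request per decode step, plus any speculative draft tokens.
    uniform_decode_query_len = 1 + num_speculative_tokens
    max_num_reqs = max(max_num_seqs, decode_max_num_seqs)
    if max_cudagraph_capture_size is not None:
        # With graph capture, the largest captured size bounds the batch.
        max_num_tokens = max_cudagraph_capture_size
    else:
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    # Ceiling-divide across TP ranks, then multiply back so the capacity is a
    # multiple of tp_size.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
    # global_bs = max_bs_per_rank * ep_world_size.
    return mc2_tokens_capacity * ep_world_size

# e.g. 256 requests, no speculative decoding, no graph capture, tp=4, ep=16:
# max_num_tokens = min(256 * 1, 512) = 256 -> capacity 256 -> global_bs 4096.
assert compute_global_bs(256, 0, 0, None, 4, 16) == 4096
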
@@ -130,7 +150,7 @@ def get_dispatch_mc2_kwargs(
130150 "expert_shard_type" : 0 ,
131151 "shared_expert_rank_num" : 0 ,
132152 "moe_expert_num" : moe_expert_num ,
133- "global_bs" : 0 ,
153+ "global_bs" : self . global_bs ,
134154 "expert_token_nums_type" : 0 ,
135155 }
136156
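A hedged reading of why "global_bs" switches from 0 to the precomputed value: per the comment added in __init__, the fixed worst-case capacity is what allows each EP rank to submit a different live num_tokens to dispatch & combine. A small illustrative check, with all values as placeholders:

ep_world_size = 16                                 # placeholder
mc2_tokens_capacity = 512                          # placeholder
global_bs = mc2_tokens_capacity * ep_world_size    # 8192
# Every rank reserves the same per-rank headroom, so ranks may submit
# different live token counts up to this shared capacity.
assert global_bs // ep_world_size == mc2_tokens_capacity
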
@@ -147,10 +167,6 @@ def get_dispatch_mc2_kwargs(
147167 "tp_world_size" : 1 ,
148168 "tp_rank_id" : 0 ,
149169 })
150- if self .a3_need_extra_args and self .enable_dispatch_v2 :
151- stage1_kwargs .update ({
152- "x_active_mask" : mc2_mask ,
153- })
154170 if self .need_expert_scale :
155171 stage1_kwargs .update ({
156172 "expert_scales" :
@@ -256,7 +272,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
256272 "expert_shard_type" : 0 ,
257273 "shared_expert_rank_num" : 0 ,
258274 "moe_expert_num" : moe_expert_num ,
259- "global_bs" : 0 ,
275+ "global_bs" : self . global_bs ,
260276 }
261277
262278 if self .with_quant :
@@ -285,9 +301,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
285301 "tp_rank_id" : 0 ,
286302 })
287303
288- if self .a3_need_extra_args and self .enable_dispatch_v2 :
289- stage3_kwargs ["x_active_mask" ] = mc2_mask
290-
291304 kwargs_mc2 .update (stage3_kwargs )
292305 return kwargs_mc2
293306
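The dispatch and combine call sites are updated symmetrically because combine reverses the routing that dispatch set up, so both operators must be sized with the same global_bs. A sketch of that invariant, using stub dicts in place of the real kwargs built by get_dispatch_mc2_kwargs and get_combine_mc_kwargs:

global_bs = 8192                                   # placeholder
dispatch_kwargs = {"moe_expert_num": 64, "global_bs": global_bs}
combine_kwargs = {"moe_expert_num": 64, "global_bs": global_bs}
assert dispatch_kwargs["global_bs"] == combine_kwargs["global_bs"]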