
Commit ace1d42

replace all_reduce for kv_consumer and support different num_tokens among all ranks
Signed-off-by: linfeng-yuan <1102311262@qq.com>
1 parent ea54388 commit ace1d42

File tree: 5 files changed, +42 −17 lines changed


vllm_ascend/ops/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -126,7 +126,8 @@ def apply(self,
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
+            random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
+            topk_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

         moe_comm_method = get_forward_context().moe_comm_method
         return moe_comm_method.fused_experts(
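The same randint_like → rand + argsort swap appears again in w4a8_dynamic.py and w8a8_dynamic.py below. A plausible reading of the change: torch.randint_like draws each of the top_k expert ids independently, so a token can be assigned the same expert twice during the forced-load-balance profile run, whereas argsort over a per-token random matrix keeps the first top_k entries of a random permutation of all experts, giving distinct ids per row with every expert equally likely. The snippet below is a minimal standalone sketch of that contrast; the sizes and variable names are illustrative, not taken from the repository.

import torch

num_tokens, top_k, global_num_experts = 4, 2, 8          # illustrative sizes
topk_ids = torch.zeros(num_tokens, top_k, dtype=torch.int32)

# Old behaviour: independent draws, so one token may be routed to the same expert twice.
old_ids = torch.randint_like(topk_ids, 0, global_num_experts)

# New behaviour: sort a per-token random matrix and keep the first top_k columns,
# i.e. the first top_k entries of a random permutation of all experts, so the ids
# within each row are distinct and the fake routing stays evenly balanced.
random_matrix = torch.rand(topk_ids.size(0), global_num_experts,
                           device=topk_ids.device)
new_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

assert all(len(set(row.tolist())) == top_k for row in new_ids)  # no duplicate experts per row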

vllm_ascend/ops/fused_moe/token_dispatcher.py

Lines changed: 25 additions & 12 deletions
@@ -25,6 +25,7 @@

 import torch
 import torch_npu
+from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import get_ep_group

 from vllm_ascend.distributed.parallel_state import get_mc2_group
@@ -100,15 +101,34 @@ def __init__(self, **kwargs):
         self.need_extra_args = (
             get_ascend_device_type() == AscendDeviceType._910_93)

-        # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
-        self.a3_need_extra_args = \
-            get_ascend_device_type() == AscendDeviceType._910_93
         # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
         # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
         # improve communication performance.
         self.need_expert_scale = is_hierarchical_communication_enabled()
         self.with_quant = False

+        # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
+        # dispatch & combine operators with different input num_tokens per rank.
+        vllm_config = get_current_vllm_config()
+        scheduler_config = vllm_config.scheduler_config
+        compilation_config = vllm_config.compilation_config
+        speculative_config = vllm_config.speculative_config
+        tp_size = vllm_config.parallel_config.tensor_parallel_size
+        uniform_decode_query_len = 1 if not speculative_config else \
+            1 + speculative_config.num_speculative_tokens
+        decode_max_num_seqs = getattr(scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        max_num_reqs = max(scheduler_config.max_num_seqs,
+                           decode_max_num_seqs)
+        if compilation_config.cudagraph_capture_sizes:
+            max_num_tokens = compilation_config.max_cudagraph_capture_size
+        else:
+            max_num_tokens = min(
+                max_num_reqs * uniform_decode_query_len, 512)
+        num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
+        mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
+        self.global_bs = mc2_tokens_capacity * self.ep_world_size
+
     def get_dispatch_mc2_kwargs(
             self,
             hidden_states: torch.Tensor,
@@ -130,7 +150,7 @@ def get_dispatch_mc2_kwargs(
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
-            "global_bs": 0,
+            "global_bs": self.global_bs,
             "expert_token_nums_type": 0,
         }

@@ -147,10 +167,6 @@ def get_dispatch_mc2_kwargs(
                 "tp_world_size": 1,
                 "tp_rank_id": 0,
             })
-        if self.a3_need_extra_args and self.enable_dispatch_v2:
-            stage1_kwargs.update({
-                "x_active_mask": mc2_mask,
-            })
         if self.need_expert_scale:
             stage1_kwargs.update({
                 "expert_scales":
@@ -256,7 +272,7 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
             "expert_shard_type": 0,
             "shared_expert_rank_num": 0,
             "moe_expert_num": moe_expert_num,
-            "global_bs": 0,
+            "global_bs": self.global_bs,
         }

         if self.with_quant:
@@ -285,9 +301,6 @@ def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
                 "tp_rank_id": 0,
             })

-        if self.a3_need_extra_args and self.enable_dispatch_v2:
-            stage3_kwargs["x_active_mask"] = mc2_mask
-
         kwargs_mc2.update(stage3_kwargs)
         return kwargs_mc2
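For reference, the global_bs arithmetic added to __init__ above can be written as a standalone helper. This is only a sketch: the function name and plain-integer parameters are illustrative stand-ins for the vllm_config fields being read, not part of the patch.

def compute_global_bs(max_num_seqs: int,
                      decode_max_num_seqs: int,
                      num_speculative_tokens: int,
                      tp_size: int,
                      ep_world_size: int,
                      max_cudagraph_capture_size: int = 0) -> int:
    # A uniform decode step contributes 1 token per request, plus any speculative draft tokens.
    uniform_decode_query_len = 1 + num_speculative_tokens
    max_num_reqs = max(max_num_seqs, decode_max_num_seqs)
    # Use the graph-capture size when capture sizes are configured; otherwise cap
    # the decode batch at 512 tokens, mirroring the branch in the hunk above.
    if max_cudagraph_capture_size:
        max_num_tokens = max_cudagraph_capture_size
    else:
        max_num_tokens = min(max_num_reqs * uniform_decode_query_len, 512)
    # Round the per-TP-rank share up and multiply back: mc2_tokens_capacity is the
    # padded per-rank capacity, and global_bs must cover the whole EP group even
    # when ranks feed different num_tokens into dispatch & combine.
    num_tokens_per_tp_rank = (max_num_tokens + tp_size - 1) // tp_size
    mc2_tokens_capacity = num_tokens_per_tp_rank * tp_size
    return mc2_tokens_capacity * ep_world_size

# Example: 256 decode requests, no speculative decoding, tp_size=4, ep_world_size=16
# -> mc2_tokens_capacity = 256, global_bs = 256 * 16 = 4096.

The computed value replaces the hard-coded 0 in both the dispatch and combine kwargs above, giving every rank the same worst-case sizing so the operators can be called with different num_tokens per rank.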

vllm_ascend/quantization/w4a8_dynamic.py

Lines changed: 2 additions & 2 deletions
@@ -371,8 +371,8 @@ def apply(
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(
-                topk_ids, 0, global_num_experts - global_redundant_expert_num)
+            random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
+            topk_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

         topk_weights = topk_weights.to(x.dtype)

vllm_ascend/quantization/w8a8_dynamic.py

Lines changed: 2 additions & 2 deletions
@@ -216,8 +216,8 @@ def apply(
         # to avoid accumulating too much tokens on a single rank.
         # currently it is only activated when doing profile runs.
         if enable_force_load_balance:
-            topk_ids = torch.randint_like(
-                topk_ids, 0, global_num_experts - global_redundant_expert_num)
+            random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
+            topk_ids = torch.argsort(random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

         topk_weights = topk_weights.to(self.in_dtype)

vllm_ascend/worker/model_runner_v1.py

Lines changed: 11 additions & 0 deletions
@@ -910,6 +910,17 @@ def _sync_metadata_across_dp(
         if self.dp_size == 1:
             return num_tokens, None, with_prefill

+        # NOTE: Here we can skip the all_reduce operation and avoid padding tokens
+        # to max_tokens_across_dp on D nodes. In MoE models, we must ensure that
+        # num_tokens DOES NOT exceed mc2_tokens_capacity, which means the moe_comm_method
+        # of each rank is MC2. It is recommended to enable the recompute scheduler for D nodes.
+        if self.is_kv_consumer and not self.in_profile_run:
+            num_tokens_after_padding = torch.tensor([num_tokens] *
+                                                    self.dp_size,
+                                                    device="cpu",
+                                                    dtype=torch.int32)
+            return num_tokens, num_tokens_after_padding, with_prefill
+
         # Sync num_tokens, with_prefill across dp ranks
         num_tokens_tensor = torch.tensor([
             num_tokens if i == self.dp_rank else 0 for i in range(self.dp_size)
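A small illustration (not from the patch) of why filling num_tokens_after_padding with the local num_tokens effectively turns off DP padding on a KV-consumer rank: every entry equals the rank's own token count, so the padding target is the local count whether downstream code reads its own slot or takes the max across DP, and no dummy tokens are appended even though other DP ranks may be running a different num_tokens.

import torch

dp_size, dp_rank, num_tokens = 4, 2, 48                   # illustrative values
num_tokens_after_padding = torch.tensor([num_tokens] * dp_size,
                                        device="cpu", dtype=torch.int32)

# Own slot and cross-DP max both equal the local count -> zero padding on this rank.
assert int(num_tokens_after_padding[dp_rank]) == num_tokens
assert int(num_tokens_after_padding.max()) == num_tokens

# Per the NOTE in the hunk above, this shortcut is only safe because a D node's
# num_tokens never exceeds mc2_tokens_capacity, so every rank keeps the MC2
# moe_comm_method and dispatch/combine already tolerate unequal counts via global_bs.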
