Skip to content

Commit 49702f5

Browse files
committed
Factor out common moe functions in collector
Signed-off-by: Ilya Sherstyuk <[email protected]>
1 parent 5198119 commit 49702f5

File tree

6 files changed

+208
-390
lines changed

6 files changed

+208
-390
lines changed

collector/helper.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import fcntl
55
import json
66
import logging
7+
import math
78
import multiprocessing as mp
89
import os
910
import signal
@@ -264,3 +265,171 @@ def log_perf(
264265
f.write(header_prefix + "\n")
265266

266267
f.write(content_prefix + "\n")
268+
269+
270+
# Helper functions for MoE
def balanced_logits(num_tokens, num_experts, topk):
    """Build router logits that spread tokens evenly across experts.

    Each token is assigned ``topk`` experts spaced ``stride`` apart so the
    load is perfectly balanced; the assignment is one-hot encoded and
    softmaxed into router logits.

    Args:
        num_tokens: Number of tokens to route.
        num_experts: Total number of experts.
        topk: Number of experts selected per token.

    Returns:
        A ``(num_tokens, num_experts)`` bfloat16 tensor of router logits.
    """
    # Imported lazily so importing this helper module does not require torch.
    import torch
    import torch.nn.functional as F

    h_selected_experts = -torch.ones([num_tokens, topk])
    stride = math.ceil(num_experts / topk)

    for token_i in range(num_tokens):
        for i in range(topk):
            if num_tokens >= stride:
                h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
            else:
                # Fewer tokens than the stride: scale the token index up so
                # the assignments still spread across the expert range.
                h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts

    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
    return router_logits
289+
290+
291+
def sample_power_law(size, alpha, xmin, xmax):
    """Draw ``size`` samples from a power law truncated to [xmin, xmax].

    Uses inverse-CDF sampling: uniform draws are mapped through the inverse
    of the truncated power-law CDF with exponent ``alpha``.

    NOTE(review): undefined for ``alpha == 1`` (zero division in the
    exponent) — callers appear to pass alpha != 1; confirm upstream.
    """
    import torch

    one_minus_alpha = 1 - alpha
    lo = xmin ** one_minus_alpha
    hi = xmax ** one_minus_alpha
    uniform_draws = torch.rand(size)
    return ((hi - lo) * uniform_draws + lo) ** (1 / one_minus_alpha)
297+
298+
299+
def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
    """Generate router logits whose per-expert token load follows a power law.

    Samples a power-law load per expert, rescales and rounds it so the total
    equals ``num_tokens * topk``, moves the most-loaded EP group to rank 0,
    and converts the resulting token->expert assignments into softmaxed
    router logits.

    Args:
        num_tokens: Number of tokens routed in one step.
        num_experts: Total number of experts across all EP ranks.
        topk: Number of experts selected per token.
        ep: Expert-parallel group count; assumes ``num_experts`` is divisible
            by ``ep`` — TODO confirm with callers.
        alpha: Power-law exponent forwarded to ``sample_power_law``.

    Returns:
        A ``(num_tokens, num_experts)`` bfloat16 tensor of router logits.
    """
    import torch
    import torch.nn.functional as F

    # Choose sampling bounds depending on whether experts are over- or
    # under-subscribed relative to the total number of assignments.
    if num_tokens * topk > num_experts:
        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
    else:
        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)

    # Total number of (token, expert) assignments we must hand out.
    target_sum = num_tokens * topk

    original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

    target_distribution = original_distribution * target_sum

    num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

    # Rounding can leave the total off by a few tokens; add/remove one token
    # at a time, preferring the most-loaded experts for additions and the
    # least-loaded for removals so the power-law shape survives.
    current_sum = num_tokens_per_expert.sum().item()
    delta = target_sum - current_sum
    if delta != 0:
        sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

        if delta > 0:
            for i in range(delta):
                expert_idx = sorted_indices[i % len(sorted_indices)]
                num_tokens_per_expert[expert_idx] += 1
        else:
            for i in range(-delta):
                expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
                if num_tokens_per_expert[expert_idx] > 0:
                    num_tokens_per_expert[expert_idx] -= 1
                else:
                    # Least-loaded expert is already empty; take from the
                    # current maximum instead of going negative.
                    num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

    if len(num_tokens_per_expert) > 1:
        sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
        assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

    # Sum each EP group's load with a ones-kernel strided Conv1d (kernel ==
    # stride == group size, so each output element is one group's total).
    with torch.no_grad():
        conv1d = torch.nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=num_experts // ep,
            stride=num_experts // ep,
            padding=0,
            bias=False,
        )
        conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
        conv1d.weight.copy_(conv1d_weights)

        res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
        max_ep_idx = torch.argmax(res).item()

    # Swap the most-loaded EP group into rank 0. The swap happens in place
    # through a view, so it mutates num_tokens_per_expert via aliasing.
    if max_ep_idx != 0:
        ep_group_size = num_experts // ep
        num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
        num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
            num_tokens_per_expert_reshaped[max_ep_idx].clone(),
            num_tokens_per_expert_reshaped[0].clone(),
        )
        num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)

    # Debug dump gated by the AIC_DEBUG environment variable.
    aic_debug = int(os.getenv("AIC_DEBUG", "0"))
    if aic_debug == 1:
        print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())

    # Expand per-expert counts into a flat list of expert ids (hottest expert
    # first), then reshape into topk selections per token.
    _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
    expert_assignments = []
    num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
    for expert_id in num_tokens_per_expert_sorted_index_lists:
        expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])

    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
    h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
    return router_logits
377+
378+
379+
# NOTE: power_law_logits_v4 was copied from power_law_logits_v3 and
# modified to restrict max tokens per expert to be less than num_tokens
def power_law_logits_v4(num_tokens, num_experts, topk, ep, alpha):
    """Generate power law distribution for token assignment to experts.

    Rejection-samples power-law loads until the hottest expert in the
    most-loaded EP group receives at most ``num_tokens`` tokens, then
    returns the per-expert token counts of that group.

    Args:
        num_tokens: Number of tokens routed in one step.
        num_experts: Total number of experts across all EP ranks.
        topk: Number of experts selected per token.
        ep: Expert-parallel group count; assumes ``num_experts`` is divisible
            by ``ep`` — TODO confirm with callers.
        alpha: Power-law exponent forwarded to ``sample_power_law``.

    Returns:
        A 1-D int64 tensor of length ``num_experts // ep`` holding the token
        count per expert for the most-loaded EP rank.
    """
    # The docstring used to sit *after* ``import torch``, which made it a
    # discarded string expression rather than the function docstring; it now
    # leads the body.
    import torch

    while True:
        # Choose sampling bounds depending on whether experts are over- or
        # under-subscribed relative to the total number of assignments.
        if num_tokens * topk > num_experts:
            num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
        else:
            num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
        target_sum = num_tokens * topk

        original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

        target_distribution = original_distribution * target_sum

        num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

        # Rounding can leave the total off by a few tokens; nudge counts one
        # token at a time so the sum matches exactly while preserving the
        # power-law shape.
        current_sum = num_tokens_per_expert.sum().item()
        delta = target_sum - current_sum
        if delta != 0:
            sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

            if delta > 0:
                for i in range(delta):
                    expert_idx = sorted_indices[i % len(sorted_indices)]
                    num_tokens_per_expert[expert_idx] += 1
            else:
                for i in range(-delta):
                    expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
                    if num_tokens_per_expert[expert_idx] > 0:
                        num_tokens_per_expert[expert_idx] -= 1
                    else:
                        # Least-loaded expert is already empty; take from the
                        # current maximum instead of going negative.
                        num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

        if len(num_tokens_per_expert) > 1:
            sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
            assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

        # Sum each EP group's load with a ones-kernel strided Conv1d (kernel
        # == stride == group size, so each output element is a group total).
        with torch.no_grad():
            conv1d = torch.nn.Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=num_experts // ep,
                stride=num_experts // ep,
                padding=0,
                bias=False,
            )
            conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
            conv1d.weight.copy_(conv1d_weights)

            res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
            max_ep_idx = torch.argmax(res).item()

        # Accept the sample only when no expert on the busiest rank exceeds
        # num_tokens; otherwise resample.
        num_tokens_per_expert_rank0 = num_tokens_per_expert.view(ep, num_experts // ep)[max_ep_idx].view(-1)
        if max(num_tokens_per_expert_rank0) <= num_tokens:
            return num_tokens_per_expert_rank0

collector/sglang/collect_moe.py

Lines changed: 9 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
3-
import math
43
import os
54
from typing import TypedDict
65

76
import pkg_resources
87
import torch
9-
import torch.nn.functional as F
108
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
119
fused_moe,
1210
get_config_dtype_str,
@@ -16,7 +14,15 @@
1614
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
1715
from sglang.srt.utils import is_hip
1816

19-
from helper import log_perf
17+
try:
18+
from helper import balanced_logits, log_perf, power_law_logits_v3
19+
except ModuleNotFoundError:
20+
import os
21+
import sys
22+
23+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24+
from helper import balanced_logits, log_perf, power_law_logits_v3
25+
2026

2127
_is_hip = is_hip()
2228

@@ -148,105 +154,6 @@ def get_moe_test_cases():
148154
return test_cases
149155

150156

151-
def balanced_logits(num_tokens, num_experts, topk):
152-
h_selected_experts = -torch.ones([num_tokens, topk])
153-
stride = math.ceil(num_experts / topk)
154-
155-
for token_i in range(num_tokens):
156-
for i in range(topk):
157-
if num_tokens >= stride:
158-
h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
159-
else:
160-
h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts
161-
162-
expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
163-
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
164-
return router_logits
165-
166-
167-
def sample_power_law(size, alpha, xmin, xmax):
168-
u = torch.rand(size)
169-
inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha))
170-
return inv_cdf
171-
172-
173-
def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
174-
if num_tokens * topk > num_experts:
175-
num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
176-
else:
177-
num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
178-
179-
target_sum = num_tokens * topk
180-
181-
original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()
182-
183-
target_distribution = original_distribution * target_sum
184-
185-
num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)
186-
187-
current_sum = num_tokens_per_expert.sum().item()
188-
delta = target_sum - current_sum
189-
if delta != 0:
190-
sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)
191-
192-
if delta > 0:
193-
for i in range(delta):
194-
expert_idx = sorted_indices[i % len(sorted_indices)]
195-
num_tokens_per_expert[expert_idx] += 1
196-
else:
197-
for i in range(-delta):
198-
expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
199-
if num_tokens_per_expert[expert_idx] > 0:
200-
num_tokens_per_expert[expert_idx] -= 1
201-
else:
202-
num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1
203-
204-
if len(num_tokens_per_expert) > 1:
205-
sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
206-
assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"
207-
208-
with torch.no_grad():
209-
conv1d = torch.nn.Conv1d(
210-
in_channels=1,
211-
out_channels=1,
212-
kernel_size=num_experts // ep,
213-
stride=num_experts // ep,
214-
padding=0,
215-
bias=False,
216-
)
217-
conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
218-
conv1d.weight.copy_(conv1d_weights)
219-
220-
res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
221-
max_ep_idx = torch.argmax(res).item()
222-
223-
if max_ep_idx != 0:
224-
ep_group_size = num_experts // ep
225-
num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
226-
num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
227-
num_tokens_per_expert_reshaped[max_ep_idx].clone(),
228-
num_tokens_per_expert_reshaped[0].clone(),
229-
)
230-
num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)
231-
232-
pet_debug = int(os.getenv("PET_DEBUG", "0"))
233-
if pet_debug == 1:
234-
print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())
235-
236-
_, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
237-
expert_assignments = []
238-
num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
239-
for expert_id in num_tokens_per_expert_sorted_index_lists:
240-
expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])
241-
242-
expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
243-
h_selected_experts = expert_assignments.reshape(topk, num_tokens).T
244-
245-
expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
246-
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
247-
return router_logits
248-
249-
250157
class BenchmarkConfig(TypedDict):
251158
BLOCK_SIZE_M: int
252159
BLOCK_SIZE_N: int

0 commit comments

Comments
 (0)