169 changes: 169 additions & 0 deletions collector/helper.py
@@ -4,6 +4,7 @@
import fcntl
import json
import logging
import math
import multiprocessing as mp
import os
import signal
@@ -264,3 +265,171 @@ def log_perf(
f.write(header_prefix + "\n")

f.write(content_prefix + "\n")


# Helper functions for MoE
def balanced_logits(num_tokens, num_experts, topk):
import torch
import torch.nn.functional as F

# h_selected_experts = -torch.ones([num_tokens, topk]).to(torch.device(device))
h_selected_experts = -torch.ones([num_tokens, topk])
stride = math.ceil(num_experts / topk)

for token_i in range(num_tokens):
for i in range(topk):
if num_tokens >= stride:
h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
else:
h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts

expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
return router_logits

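A minimal usage sketch (mine, not part of the PR; values are illustrative). balanced_logits spreads each token's top-k picks stride experts apart, so the underlying assignment is near-uniform across experts:

logits = balanced_logits(num_tokens=4, num_experts=8, topk=2)
assert logits.shape == (4, 8)  # one softmax row per token
# With these values every expert is selected exactly once across the tokens.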

def sample_power_law(size, alpha, xmin, xmax):
import torch

u = torch.rand(size)
inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha))
return inv_cdf

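For context, sample_power_law is standard inverse-transform sampling of a truncated power law: for a density p(x) proportional to x^(-alpha) on [xmin, xmax] with alpha != 1, plugging a uniform draw u into the inverse CDF yields x = ((xmax^(1-alpha) - xmin^(1-alpha)) * u + xmin^(1-alpha))^(1/(1-alpha)), which is exactly the expression above. A sanity-check sketch (mine, not from the PR):

samples = sample_power_law(size=100_000, alpha=1.2, xmin=1.0, xmax=64.0)
# Samples stay inside [xmin, xmax] and skew heavily toward xmin.
print(samples.min().item(), samples.max().item(), samples.median().item())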

def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
import torch
import torch.nn.functional as F

if num_tokens * topk > num_experts:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
else:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)

target_sum = num_tokens * topk

original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

target_distribution = original_distribution * target_sum

num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

current_sum = num_tokens_per_expert.sum().item()
delta = target_sum - current_sum
if delta != 0:
sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

if delta > 0:
for i in range(delta):
expert_idx = sorted_indices[i % len(sorted_indices)]
num_tokens_per_expert[expert_idx] += 1
else:
for i in range(-delta):
expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
if num_tokens_per_expert[expert_idx] > 0:
num_tokens_per_expert[expert_idx] -= 1
else:
num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

if len(num_tokens_per_expert) > 1:
sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

with torch.no_grad():
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=1,
kernel_size=num_experts // ep,
stride=num_experts // ep,
padding=0,
bias=False,
)
conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
conv1d.weight.copy_(conv1d_weights)

res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
max_ep_idx = torch.argmax(res).item()

if max_ep_idx != 0:
ep_group_size = num_experts // ep
num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
num_tokens_per_expert_reshaped[max_ep_idx].clone(),
num_tokens_per_expert_reshaped[0].clone(),
)
num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)

aic_debug = int(os.getenv("AIC_DEBUG", "0"))
if aic_debug == 1:
print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())

_, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
expert_assignments = []
num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
for expert_id in num_tokens_per_expert_sorted_index_lists:
expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])

expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
return router_logits

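A usage sketch for the v3 generator (illustrative values, not from the PR; ep must divide num_experts, and alpha must differ from 1 since the inverse CDF divides by 1 - alpha):

# Skewed routing for 128 tokens over 64 experts, top-8, 8 EP ranks.
router_logits = power_law_logits_v3(num_tokens=128, num_experts=64, topk=8, ep=8, alpha=1.01)
assert router_logits.shape == (128, 64)  # one softmax row per token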

# NOTE: power_law_logits_v4 was copied from power_law_logits_v3 and
# modified to cap the max tokens per expert at num_tokens
def power_law_logits_v4(num_tokens, num_experts, topk, ep, alpha):
    """Generate a power-law distribution for token assignment to experts."""
    import torch

while True:
if num_tokens * topk > num_experts:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
else:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)
target_sum = num_tokens * topk

original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

target_distribution = original_distribution * target_sum

num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

current_sum = num_tokens_per_expert.sum().item()
delta = target_sum - current_sum
if delta != 0:
sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

if delta > 0:
for i in range(delta):
expert_idx = sorted_indices[i % len(sorted_indices)]
num_tokens_per_expert[expert_idx] += 1
else:
for i in range(-delta):
expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
if num_tokens_per_expert[expert_idx] > 0:
num_tokens_per_expert[expert_idx] -= 1
else:
num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

if len(num_tokens_per_expert) > 1:
sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

with torch.no_grad():
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=1,
kernel_size=num_experts // ep,
stride=num_experts // ep,
padding=0,
bias=False,
)
conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
conv1d.weight.copy_(conv1d_weights)

res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
max_ep_idx = torch.argmax(res).item()
num_tokens_per_expert_rank0 = num_tokens_per_expert.view(ep, num_experts // ep)[max_ep_idx].view(-1)
if max(num_tokens_per_expert_rank0) <= num_tokens:
return num_tokens_per_expert_rank0
111 changes: 9 additions & 102 deletions collector/sglang/collect_moe.py
@@ -1,12 +1,10 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import math
import os
from typing import TypedDict

import pkg_resources
import torch
import torch.nn.functional as F
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe,
get_config_dtype_str,
@@ -16,7 +14,15 @@
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip

from helper import log_perf
try:
from helper import balanced_logits, log_perf, power_law_logits_v3
except ModuleNotFoundError:
import os
import sys

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from helper import balanced_logits, log_perf, power_law_logits_v3


_is_hip = is_hip()

@@ -148,105 +154,6 @@ def get_moe_test_cases():
return test_cases


def balanced_logits(num_tokens, num_experts, topk):
h_selected_experts = -torch.ones([num_tokens, topk])
stride = math.ceil(num_experts / topk)

for token_i in range(num_tokens):
for i in range(topk):
if num_tokens >= stride:
h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts
else:
h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts

expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
return router_logits


def sample_power_law(size, alpha, xmin, xmax):
u = torch.rand(size)
inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha))
return inv_cdf


def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
if num_tokens * topk > num_experts:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
else:
num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)

target_sum = num_tokens * topk

original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

target_distribution = original_distribution * target_sum

num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

current_sum = num_tokens_per_expert.sum().item()
delta = target_sum - current_sum
if delta != 0:
sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

if delta > 0:
for i in range(delta):
expert_idx = sorted_indices[i % len(sorted_indices)]
num_tokens_per_expert[expert_idx] += 1
else:
for i in range(-delta):
expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
if num_tokens_per_expert[expert_idx] > 0:
num_tokens_per_expert[expert_idx] -= 1
else:
num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

if len(num_tokens_per_expert) > 1:
sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

with torch.no_grad():
conv1d = torch.nn.Conv1d(
in_channels=1,
out_channels=1,
kernel_size=num_experts // ep,
stride=num_experts // ep,
padding=0,
bias=False,
)
conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
conv1d.weight.copy_(conv1d_weights)

res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
max_ep_idx = torch.argmax(res).item()

if max_ep_idx != 0:
ep_group_size = num_experts // ep
num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
num_tokens_per_expert_reshaped[max_ep_idx].clone(),
num_tokens_per_expert_reshaped[0].clone(),
)
num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)

pet_debug = int(os.getenv("PET_DEBUG", "0"))
if pet_debug == 1:
print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())

_, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
expert_assignments = []
num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
for expert_id in num_tokens_per_expert_sorted_index_lists:
expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])

expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
router_logits = F.softmax(expert_map.bfloat16(), dim=1)
return router_logits


class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int