# Helper functions for MoE
def balanced_logits(num_tokens, num_experts, topk):
    """Build router logits that spread tokens evenly over the experts.

    Token ``t`` is given ``topk`` expert ids spaced ``ceil(num_experts / topk)``
    apart (modulo ``num_experts``). When there are fewer tokens than that
    stride, the starting expert of each token is stretched by
    ``stride / num_tokens`` so the tokens still cover the whole stride.

    Args:
        num_tokens: Number of tokens (rows of the result).
        num_experts: Number of experts (columns of the result).
        topk: Number of expert slots filled per token.

    Returns:
        A ``(num_tokens, num_experts)`` bfloat16 tensor: the row-wise softmax of
        the per-token expert selection counts.
    """
    import torch
    import torch.nn.functional as F

    stride = math.ceil(num_experts / topk)

    # Vectorised form of the former per-element double loop. The arithmetic is
    # done in float64 (what the Python scalars used to be) so results match.
    token_ids = torch.arange(num_tokens, dtype=torch.float64).unsqueeze(1)
    slot_offsets = (torch.arange(topk, dtype=torch.float64) * stride).unsqueeze(0)
    if num_tokens >= stride:
        first_expert = token_ids
    else:
        first_expert = token_ids * stride / num_tokens
    selected = torch.remainder(first_expert + slot_offsets, num_experts)

    # The loop version stored each value in a float32 tensor before truncating
    # to an integer id; keep that round-trip so the truncation is unchanged.
    h_selected_experts = selected.to(torch.float32).long()

    expert_map = F.one_hot(h_selected_experts, num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
    return router_logits


def sample_power_law(size, alpha, xmin, xmax):
    """Draw samples from a density proportional to ``x ** -alpha`` on [xmin, xmax].

    Uses inverse-CDF sampling on uniform random numbers from ``torch.rand``.

    Args:
        size: Size argument forwarded to ``torch.rand``.
        alpha: Power-law exponent. ``alpha == 1`` is handled separately because
            the general inverse CDF divides by ``1 - alpha``.
        xmin: Lower bound of the support (must be positive for ``alpha == 1``).
        xmax: Upper bound of the support.

    Returns:
        A tensor of samples with the shape produced by ``torch.rand(size)``.
    """
    import torch

    u = torch.rand(size)
    if alpha == 1:
        # Limit of the general formula as alpha -> 1: a log-uniform variable.
        # The general branch would raise ZeroDivisionError here.
        return xmin * (xmax / xmin) ** u

    exponent = 1 - alpha
    low = xmin**exponent
    high = xmax**exponent
    return ((high - low) * u + low) ** (1 / exponent)


def _power_law_token_counts(num_tokens, num_experts, topk, alpha):
    """Sample integer per-expert token counts that sum to ``num_tokens * topk``."""
    import torch

    if num_tokens * topk > num_experts:
        raw = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
    else:
        raw = sample_power_law(num_experts, alpha, 0.01, 2)

    target_sum = num_tokens * topk
    counts = torch.round(raw / raw.sum() * target_sum).to(torch.int64)

    # Rounding can leave the total off by a few tokens; hand the difference
    # out one token at a time, cycling through the experts in sorted order.
    delta = target_sum - counts.sum().item()
    if delta != 0:
        sorted_indices = torch.argsort(counts, descending=True)
        num_sorted = len(sorted_indices)
        if delta > 0:
            # Surplus goes to the most loaded experts first.
            for i in range(delta):
                counts[sorted_indices[i % num_sorted]] += 1
        else:
            # Deficit is taken from the least loaded experts first; an expert
            # that is already empty passes the decrement to the current maximum.
            for i in range(-delta):
                expert_idx = sorted_indices[-(i % num_sorted) - 1]
                if counts[expert_idx] > 0:
                    counts[expert_idx] -= 1
                else:
                    counts[torch.argmax(counts)] -= 1
    return counts


def _heaviest_ep_group(num_tokens_per_expert, num_experts, ep):
    """Return the index of the contiguous expert group holding the most tokens."""
    import torch

    group_size = num_experts // ep
    # Same windows as a stride == kernel_size sliding sum: experts that do not
    # fill a whole trailing group are ignored. Summing integers directly avoids
    # the float conversion and the throwaway Conv1d module used previously.
    num_groups = num_experts // group_size
    group_totals = num_tokens_per_expert[: num_groups * group_size].reshape(num_groups, group_size).sum(dim=1)
    return torch.argmax(group_totals).item()


def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
    """Build router logits whose per-expert load follows a power law.

    Per-expert token counts are sampled with ``sample_power_law``, rescaled to
    sum to ``num_tokens * topk``, and the most loaded group of
    ``num_experts // ep`` consecutive experts is swapped into group 0. The
    counts are then expanded into a ``(num_tokens, topk)`` assignment table.
    An expert whose count exceeds ``num_tokens`` can appear more than once in
    the same token's row.

    Args:
        num_tokens: Number of tokens (rows of the result).
        num_experts: Number of experts (columns of the result).
        topk: Number of expert slots filled per token.
        ep: Number of contiguous expert groups (presumably the expert-parallel
            size; confirm against callers).
        alpha: Power-law exponent forwarded to ``sample_power_law``.

    Returns:
        A ``(num_tokens, num_experts)`` bfloat16 tensor: the row-wise softmax of
        the per-token expert selection counts.
    """
    import torch
    import torch.nn.functional as F

    num_tokens_per_expert = _power_law_token_counts(num_tokens, num_experts, topk, alpha)
    max_ep_idx = _heaviest_ep_group(num_tokens_per_expert, num_experts, ep)

    if max_ep_idx != 0:
        # Swap the heaviest group into position 0.
        ep_group_size = num_experts // ep
        grouped = num_tokens_per_expert.view(ep, ep_group_size)
        grouped[0], grouped[max_ep_idx] = (
            grouped[max_ep_idx].clone(),
            grouped[0].clone(),
        )
        num_tokens_per_expert = grouped.view(-1)

    aic_debug = int(os.getenv("AIC_DEBUG", "0"))
    if aic_debug == 1:
        print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())

    # Lay the expert ids out from most to least loaded, each repeated by its
    # token count (clamped at 0, as list repetition treated negative counts).
    _, sorted_expert_ids = torch.sort(num_tokens_per_expert, descending=True)
    expert_assignments = torch.repeat_interleave(
        sorted_expert_ids, num_tokens_per_expert[sorted_expert_ids].clamp(min=0)
    )
    h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
    return router_logits


# NOTE: power_law_logits_v4 shares its sampling with power_law_logits_v3 but
# rejects draws in which an expert of the heaviest group exceeds num_tokens.
def power_law_logits_v4(num_tokens, num_experts, topk, ep, alpha):
    """Generate power law distribution for token assignment to experts.

    Resamples until no expert in the most loaded group of
    ``num_experts // ep`` consecutive experts holds more than ``num_tokens``
    tokens. There is no retry limit.

    Args:
        num_tokens: Number of tokens.
        num_experts: Number of experts.
        topk: Number of expert slots filled per token.
        ep: Number of contiguous expert groups (presumably the expert-parallel
            size; confirm against callers).
        alpha: Power-law exponent forwarded to ``sample_power_law``.

    Returns:
        A 1-D int64 tensor of length ``num_experts // ep`` with the per-expert
        token counts of the most loaded group (counts, not logits).
    """
    while True:
        num_tokens_per_expert = _power_law_token_counts(num_tokens, num_experts, topk, alpha)
        max_ep_idx = _heaviest_ep_group(num_tokens_per_expert, num_experts, ep)
        num_tokens_per_expert_rank0 = num_tokens_per_expert.view(ep, num_experts // ep)[max_ep_idx].view(-1)
        if num_tokens_per_expert_rank0.max().item() <= num_tokens:
            return num_tokens_per_expert_rank0
num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits - - -def sample_power_law(size, alpha, xmin, xmax): - u = torch.rand(size) - inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha)) - return inv_cdf - - -def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha): - if num_tokens * topk > num_experts: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8) - else: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2) - - target_sum = num_tokens * topk - - original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum() - - target_distribution = original_distribution * target_sum - - num_tokens_per_expert = torch.round(target_distribution).to(torch.int64) - - current_sum = num_tokens_per_expert.sum().item() - delta = target_sum - current_sum - if delta != 0: - sorted_indices = torch.argsort(num_tokens_per_expert, descending=True) - - if delta > 0: - for i in range(delta): - expert_idx = sorted_indices[i % len(sorted_indices)] - num_tokens_per_expert[expert_idx] += 1 - else: - for i in range(-delta): - expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1] - if num_tokens_per_expert[expert_idx] > 0: - num_tokens_per_expert[expert_idx] -= 1 - else: - num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1 - - if len(num_tokens_per_expert) > 1: - sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0] - assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted" - - with torch.no_grad(): - conv1d = torch.nn.Conv1d( - in_channels=1, - out_channels=1, - kernel_size=num_experts // ep, - stride=num_experts // ep, - padding=0, - bias=False, - ) - conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]) - conv1d.weight.copy_(conv1d_weights) - - res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float()) - max_ep_idx = 
torch.argmax(res).item() - - if max_ep_idx != 0: - ep_group_size = num_experts // ep - num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size) - num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = ( - num_tokens_per_expert_reshaped[max_ep_idx].clone(), - num_tokens_per_expert_reshaped[0].clone(), - ) - num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1) - - pet_debug = int(os.getenv("PET_DEBUG", "0")) - if pet_debug == 1: - print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item()) - - _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True) - expert_assignments = [] - num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist() - for expert_id in num_tokens_per_expert_sorted_index_lists: - expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id]) - - expert_assignments = torch.tensor(expert_assignments, dtype=torch.long) - h_selected_experts = expert_assignments.reshape(topk, num_tokens).T - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits - - class BenchmarkConfig(TypedDict): BLOCK_SIZE_M: int BLOCK_SIZE_N: int diff --git a/collector/sglang/collect_wideep_deepep_moe.py b/collector/sglang/collect_wideep_deepep_moe.py index 7cecfb38..e5c71505 100755 --- a/collector/sglang/collect_wideep_deepep_moe.py +++ b/collector/sglang/collect_wideep_deepep_moe.py @@ -24,13 +24,13 @@ ) try: - from helper import log_perf + from helper import log_perf, power_law_logits_v4 except ModuleNotFoundError: import os import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from helper import log_perf + from helper import log_perf, power_law_logits_v4 import pkg_resources DEEPSEEK_MODEL_PATH = os.environ.get("DEEPSEEK_MODEL_PATH", "/deepseek-v3") @@ -82,69 +82,6 @@ def 
get_moe_decode_test_cases(): return test_cases -def sample_power_law(size, alpha, xmin, xmax): - u = torch.rand(size) - inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha)) - return inv_cdf - - -# NOTE: power_law_logits_v4 was copied from aiconfigurator/collector/trtllm/collect_moe.py and -# modified to restrict max tokens per expert to be less than num_tokens -def power_law_logits_v4(num_tokens, num_experts, topk, ep, alpha): - """Generate power law distribution for token assignment to experts""" - while True: - if num_tokens * topk > num_experts: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8) - else: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2) - target_sum = num_tokens * topk - - original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum() - - target_distribution = original_distribution * target_sum - - num_tokens_per_expert = torch.round(target_distribution).to(torch.int64) - - current_sum = num_tokens_per_expert.sum().item() - delta = target_sum - current_sum - if delta != 0: - sorted_indices = torch.argsort(num_tokens_per_expert, descending=True) - - if delta > 0: - for i in range(delta): - expert_idx = sorted_indices[i % len(sorted_indices)] - num_tokens_per_expert[expert_idx] += 1 - else: - for i in range(-delta): - expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1] - if num_tokens_per_expert[expert_idx] > 0: - num_tokens_per_expert[expert_idx] -= 1 - else: - num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1 - - if len(num_tokens_per_expert) > 1: - sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0] - assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted" - - with torch.no_grad(): - conv1d = torch.nn.Conv1d( - in_channels=1, - out_channels=1, - kernel_size=num_experts // ep, - stride=num_experts // ep, - padding=0, - bias=False, - ) - conv1d_weights = 
torch.tensor([1 for _ in range(num_experts // ep)]) - conv1d.weight.copy_(conv1d_weights) - - res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float()) - max_ep_idx = torch.argmax(res).item() - num_tokens_per_expert_rank0 = num_tokens_per_expert.view(ep, num_experts // ep)[max_ep_idx].view(-1) - if max(num_tokens_per_expert_rank0) <= num_tokens: - return num_tokens_per_expert_rank0 - - def load_model_with_dummy_weights(server_args, port_args, tp_rank): """Load model with dummy weights and limited layers for MoE testing""" suppress_other_loggers() diff --git a/collector/trtllm/collect_moe.py b/collector/trtllm/collect_moe.py index b337066e..66ee964d 100755 --- a/collector/trtllm/collect_moe.py +++ b/collector/trtllm/collect_moe.py @@ -3,12 +3,10 @@ import glob import json -import math import os import tensorrt_llm import torch -import torch.nn.functional as F from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate @@ -16,7 +14,14 @@ from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig -from helper import get_sm_version, log_perf +try: + from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3 +except ModuleNotFoundError: + import os + import sys + + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3 aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112 @@ -55,104 +60,6 @@ def cleanup_empty_json_files(directory): print(f"Total deleted {deleted_count} invalid JSON files from {directory}") -def balanced_logits(num_tokens, num_experts, topk, device): - h_selected_experts = -torch.ones([num_tokens, topk]).to(torch.device(device)) - stride = math.ceil(num_experts / topk) - - for token_i in range(num_tokens): - for i in 
range(topk): - if num_tokens >= stride: - h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts - else: - h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits - - -def sample_power_law(size, alpha, xmin, xmax): - u = torch.rand(size) - inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha)) - return inv_cdf - - -def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha, device): - if num_tokens * topk > num_experts: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8).to(torch.device(device)) - else: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2).to(torch.device(device)) - - target_sum = num_tokens * topk - - original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum() - - target_distribution = original_distribution * target_sum - - num_tokens_per_expert = torch.round(target_distribution).to(torch.int64) - - current_sum = num_tokens_per_expert.sum().item() - delta = target_sum - current_sum - if delta != 0: - sorted_indices = torch.argsort(num_tokens_per_expert, descending=True) - - if delta > 0: - for i in range(delta): - expert_idx = sorted_indices[i % len(sorted_indices)] - num_tokens_per_expert[expert_idx] += 1 - else: - for i in range(-delta): - expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1] - if num_tokens_per_expert[expert_idx] > 0: - num_tokens_per_expert[expert_idx] -= 1 - else: - num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1 - - if len(num_tokens_per_expert) > 1: - sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0] - assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted" - - with torch.no_grad(): - conv1d = torch.nn.Conv1d( - 
in_channels=1, - out_channels=1, - kernel_size=num_experts // ep, - stride=num_experts // ep, - padding=0, - bias=False, - ).to(torch.device(device)) - conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]).to(torch.device(device)) - conv1d.weight.copy_(conv1d_weights) - - res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float()) - max_ep_idx = torch.argmax(res).item() - - if max_ep_idx != 0: - ep_group_size = num_experts // ep - num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size) - num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = ( - num_tokens_per_expert_reshaped[max_ep_idx].clone(), - num_tokens_per_expert_reshaped[0].clone(), - ) - num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1) - - if aic_debug == 2: - print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item()) - - _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True) - expert_assignments = [] - num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist() - for expert_id in num_tokens_per_expert_sorted_index_lists: - expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id]) - - expert_assignments = torch.tensor(expert_assignments, dtype=torch.long).to(torch.device(device)) - h_selected_experts = expert_assignments.reshape(topk, num_tokens).T - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits - - def get_moe_test_cases(): num_tokens = [ 1, @@ -320,6 +227,10 @@ def run_moe_torch( power_law_alpha=0.0, device="cuda:0", ): + device = torch.device(device) + torch.cuda.set_device(device) + torch.set_default_device(device) + # moe type support float16, fp8_qdq, fp8_block, w4a8, nvfp4(not implemented yet) dtype = torch.bfloat16 quant_group_size = 128 @@ -458,9 +369,7 @@ def fp32_to_mxfp4(tensor): 
hidden_states_max_tokens = torch.randn([num_tokens_lists[-1], hidden_size]).bfloat16().to(torch.device(device)) - logits_max_tokens = balanced_logits(num_tokens_lists[-1], num_experts, topk, torch.device(device)).to( - router_logits_dtype - ) + logits_max_tokens = balanced_logits(num_tokens_lists[-1], num_experts, topk).to(router_logits_dtype) # dty run torch.cuda.synchronize() diff --git a/collector/trtllm/collect_moe_pre_0_20.py b/collector/trtllm/collect_moe_pre_0_20.py index 35dd1bf7..d39e3d9e 100644 --- a/collector/trtllm/collect_moe_pre_0_20.py +++ b/collector/trtllm/collect_moe_pre_0_20.py @@ -1,34 +1,23 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import math import tensorrt_llm import torch -import torch.nn.functional as F from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.modules.fused_moe import FusedMoE, RenormalizeMoeRoutingMethod from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig from torch.nn.parameter import Parameter -from helper import get_sm_version, log_perf +try: + from helper import balanced_logits, get_sm_version, log_perf +except ModuleNotFoundError: + import os + import sys - -def balanced_logits(num_tokens, num_experts, topk): - h_selected_experts = -torch.ones([num_tokens, topk]) - stride = math.ceil(num_experts / topk) - - for token_i in range(num_tokens): - for i in range(topk): - if num_tokens >= stride: - h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts - else: - h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from helper import 
balanced_logits, get_sm_version, log_perf def get_moe_test_cases(): diff --git a/collector/trtllm/collect_moe_pre_1_0.py b/collector/trtllm/collect_moe_pre_1_0.py index 42ef674b..de8ebe53 100644 --- a/collector/trtllm/collect_moe_pre_1_0.py +++ b/collector/trtllm/collect_moe_pre_1_0.py @@ -1,12 +1,10 @@ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import math import os import tensorrt_llm import torch -import torch.nn.functional as F from tensorrt_llm._torch.autotuner import AutoTuner, autotune from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3Gate @@ -15,107 +13,16 @@ from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig from torch.nn.parameter import Parameter -from helper import get_sm_version, log_perf +try: + from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3 +except ModuleNotFoundError: + import os + import sys -aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112 - - -def balanced_logits(num_tokens, num_experts, topk): - h_selected_experts = -torch.ones([num_tokens, topk]) - stride = math.ceil(num_experts / topk) - - for token_i in range(num_tokens): - for i in range(topk): - if num_tokens >= stride: - h_selected_experts[token_i][i] = (token_i + i * stride) % num_experts - else: - h_selected_experts[token_i][i] = (token_i * stride / num_tokens + i * stride) % num_experts - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits - - -def sample_power_law(size, alpha, xmin, xmax): - u = torch.rand(size) - inv_cdf = ((xmax ** (1 - alpha) - xmin ** (1 - alpha)) * u + xmin ** (1 - alpha)) ** (1 / (1 - alpha)) - return inv_cdf - - -def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha): - if num_tokens * topk > num_experts: 
- num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8) - else: - num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2) - - target_sum = num_tokens * topk + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + from helper import balanced_logits, get_sm_version, log_perf, power_law_logits_v3 - original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum() - - target_distribution = original_distribution * target_sum - - num_tokens_per_expert = torch.round(target_distribution).to(torch.int64) - - current_sum = num_tokens_per_expert.sum().item() - delta = target_sum - current_sum - if delta != 0: - sorted_indices = torch.argsort(num_tokens_per_expert, descending=True) - - if delta > 0: - for i in range(delta): - expert_idx = sorted_indices[i % len(sorted_indices)] - num_tokens_per_expert[expert_idx] += 1 - else: - for i in range(-delta): - expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1] - if num_tokens_per_expert[expert_idx] > 0: - num_tokens_per_expert[expert_idx] -= 1 - else: - num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1 - - if len(num_tokens_per_expert) > 1: - sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0] - assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted" - - with torch.no_grad(): - conv1d = torch.nn.Conv1d( - in_channels=1, - out_channels=1, - kernel_size=num_experts // ep, - stride=num_experts // ep, - padding=0, - bias=False, - ) - conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)]) - conv1d.weight.copy_(conv1d_weights) - - res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float()) - max_ep_idx = torch.argmax(res).item() - - if max_ep_idx != 0: - ep_group_size = num_experts // ep - num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size) - num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = ( - 
num_tokens_per_expert_reshaped[max_ep_idx].clone(), - num_tokens_per_expert_reshaped[0].clone(), - ) - num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1) - - if aic_debug == 2: - print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item()) - - _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True) - expert_assignments = [] - num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist() - for expert_id in num_tokens_per_expert_sorted_index_lists: - expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id]) - - expert_assignments = torch.tensor(expert_assignments, dtype=torch.long) - h_selected_experts = expert_assignments.reshape(topk, num_tokens).T - - expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1) - router_logits = F.softmax(expert_map.bfloat16(), dim=1) - return router_logits +aic_debug = int(os.getenv("aic_moe_debug", "0")) # noqa: SIM112 def get_moe_test_cases():