diff --git a/collector/collect.py b/collector/collect.py index f31740bd..9ac6682c 100644 --- a/collector/collect.py +++ b/collector/collect.py @@ -377,7 +377,7 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): 'module': 'trtllm.collect_mla', 'get_func': 'get_context_mla_test_cases', 'run_func': 'run_mla', - 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') + 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith(('1.1', '1.2')) else 'trtllm.collect_mla' }, { @@ -386,7 +386,7 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): 'module': 'trtllm.collect_mla', 'get_func': 'get_generation_mla_test_cases', 'run_func': 'run_mla', - 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith('1.1') + 'version_handler': lambda v: 'trtllm.collect_mla_1_1rc2' if v.startswith(('1.1', '1.2')) else 'trtllm.collect_mla' }, @@ -431,9 +431,41 @@ def collect_trtllm(num_processes: int, ops: List[str]=None): 'run_func': 'run_moe_torch', 'version_handler': lambda v: 'trtllm.collect_moe_pre_0_20' if v.startswith('0.20.0') else 'trtllm.collect_moe_pre_1_0' if v.startswith(('0.21.0', '1.0.0')) - else 'trtllm.collect_moe' if v.startswith(('1.1.0')) + else 'trtllm.collect_moe' if v.startswith(('1.1.', '1.2.')) else None - } + }, + + # CONV 1D collections + { + 'name': 'trtllm', + 'type': 'conv1d_fn', + 'module': 'trtllm.collect_conv1d', + 'get_func': 'get_conv1d_fn_test_cases', + 'run_func': 'run_conv1d_fn' + }, + { + 'name': 'trtllm', + 'type': 'conv1d_update', + 'module': 'trtllm.collect_conv1d', + 'get_func': 'get_conv1d_update_test_cases', + 'run_func': 'run_conv1d_update' + }, + + # Gated Delta Rule collections + { + 'name': 'trtllm', + 'type': 'chunk_gated_delta_rule', + 'module': 'trtllm.collect_gated_delta_rule', + 'get_func': 'get_chunk_gated_delta_rule_test_cases', + 'run_func': 'run_chunk_gated_delta_rule' + }, + { + 'name': 'trtllm', + 'type': 'gated_delta_rule_update', + 'module': 'trtllm.collect_gated_delta_rule', + 'get_func': 'get_gated_delta_rule_update_test_cases', + 'run_func': 'run_gated_delta_rule_update' + }, ] for collection in collections: @@ -529,7 +561,8 @@ def main(): parser.add_argument('--ops', nargs='*', type=str, choices=['gemm_trt', 'gemm', 'mla_context', 'mla_generation', 'attention_context', 'attention_generation', 'mla_bmm_gen_pre', - 'mla_bmm_gen_post', 'moe'], + 'mla_bmm_gen_post', 'moe', 'conv1d_fn', 'conv1d_update', + 'chunk_gated_delta_rule', 'gated_delta_rule_update'], help='Run only specified collection items.
Leave empty to run all.', default=None) args = parser.parse_args() diff --git a/collector/sglang/collect_moe.py b/collector/sglang/collect_moe.py index 33dfaa4c..f445828f 100644 --- a/collector/sglang/collect_moe.py +++ b/collector/sglang/collect_moe.py @@ -41,13 +41,15 @@ def get_moe_test_cases(): #[2048,1408,4,60], #qwen1.5_moe #[2048,1408,6,64], #deepseekv1_moe #[5120,1536,6,160], #deepseekv2 - model_config_list=[[4096,14336,2,8,'MOE_Mixtral8x7B'],# mixtral_8x7b - [6144,16384,2,8,'MOE_Mixtral8x22B'],# mixtral_8x22b - [7168,2048,8,256,'DEEPSEEK_V3'], # deepseekv3, will have 1 shared expert - [4096,1536,8,128, 'QWEN3_235B'], # qwen3-moe, 235b-a22b - [6144,2560,8,160, 'QWEN3_480B'], # qwen3-moe, 480b-a35b - [7168,2048,8,384, 'KIMI_K2'], # kimi k2 - ] + model_config_list=[ + # [4096,14336,2,8,'MOE_Mixtral8x7B'],# mixtral_8x7b + # [6144,16384,2,8,'MOE_Mixtral8x22B'],# mixtral_8x22b + # [7168,2048,8,256,'DEEPSEEK_V3'], # deepseekv3, will have 1 shared expert + # [4096,1536,8,128, 'QWEN3_235B'], # qwen3-moe, 235b-a22b + # [6144,2560,8,160, 'QWEN3_480B'], # qwen3-moe, 480b-a35b + # [7168,2048,8,384, 'KIMI_K2'], # kimi k2 + [2048,512,10,512, 'QWEN3_NEXT_80B'], # qwen3-next, 80b-a3b (hidden, moe_inter, topk, num_experts; matches the QWEN3_NEXT_80B entry in common.py) + ] moe_list=['float16', 'fp8_block'] test_cases=[] diff --git a/collector/trtllm/collect_conv1d.py b/collector/trtllm/collect_conv1d.py new file mode 100644 index 00000000..32a936a5 --- /dev/null +++ b/collector/trtllm/collect_conv1d.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import os +from cuda import cuda +import torch +import tensorrt_llm +from tensorrt_llm._torch.modules.mamba.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from helper import log_perf + +def get_conv1d_fn_test_cases(): + """ + Generate test cases for Conv1DFn operations. + + Test parameters: + - batch_size: batch size + - isl: sequence length + - conv_kernel_size: size of the convolution kernel + - conv_dim: dimension of the convolution + - tp_size: attention tensor parallel size + """ + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + tp_sizes = [1, 2, 4, 8] + conv_dims = [64, 128, 256, 512, 768, 1024, 1536, 2048, 3072, 4096] + kernel_sizes = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] + + test_cases = [] + for batch_size in b_list: + for isl in s_list: + for tp_size in tp_sizes: + for conv_dim in conv_dims: + for kernel_size in kernel_sizes: + test_cases.append([batch_size, isl, kernel_size, conv_dim, tp_size, 'conv1d_fn_perf.txt']) + + return test_cases + + +def run_conv1d_fn(batch_size, isl, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): + """ + Run Conv1DFn performance benchmarking.
+ + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Attention tensor parallel size + perf_filename: Output file for performance results + device: CUDA device to use + """ + dtype = torch.bfloat16 + # Create input with proper 3D shape: (batch_size, dim, seqlen) + mixed_qkv = torch.randn((batch_size, conv_dim // tp_size, isl), dtype=dtype, device=device) + conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + # TODO: measure optional arguments + causal_conv1d_fn( + mixed_qkv, + conv1d_weights, + ) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'batch_size': batch_size, + 'isl': isl, + 'conv_kernel_size': conv_kernel_size, + 'conv_dim': conv_dim, + 'tp_size': tp_size, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='conv1d_fn', + kernel_source='default', + perf_filename=perf_filename + ) + +def get_conv1d_update_test_cases(): + """ + Generate test cases for Conv1DUpdate operations. + + Test parameters: + - batch_size: batch size + - conv_kernel_size: size of the convolution kernel (must be between 2 and 4) + - conv_dim: dimension of the convolution + - tp_size: attention tensor parallel size + + Note: isl (sequence length) is not used for conv1d_update as it processes + individual tokens in incremental/streaming inference mode. + Note: causal_conv1d_update only supports kernel widths between 2 and 4. + """ + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + tp_sizes = [1, 2, 4, 8] + conv_dims = [1,2,4,8,16,32] + kernel_sizes = [2, 3, 4] # causal_conv1d_update only supports widths 2-4 + + test_cases = [] + for batch_size in b_list: + for tp_size in tp_sizes: + for conv_dim in conv_dims: + for kernel_size in kernel_sizes: + test_cases.append([batch_size, kernel_size, conv_dim, tp_size, 'conv1d_update_perf.txt']) + + return test_cases + + +def run_conv1d_update(batch_size, conv_kernel_size, conv_dim, tp_size, perf_filename, device='cuda:0'): + """ + Run Conv1DUpdate performance benchmarking. + + Args: + batch_size: Batch size + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Attention tensor parallel size + perf_filename: Output file for performance results + device: CUDA device to use + + Note: isl (sequence length) is not used as conv1d_update processes individual + tokens in incremental/streaming inference mode. 
+ """ + dtype = torch.bfloat16 + # Create input with shape (batch_size, dim) + mixed_qkv = torch.randn((batch_size, conv_dim // tp_size), dtype=dtype, device=device) + # Create conv_state with shape (batch_size, dim, kernel_size - 1) + conv_state = torch.randn((batch_size, conv_dim // tp_size, conv_kernel_size - 1), dtype=dtype, device=device) + conv1d_weights = torch.randn((conv_dim // tp_size, conv_kernel_size), dtype=dtype, device=device) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + # TODO: measure optional arguments + causal_conv1d_update( + mixed_qkv, + conv_state, + conv1d_weights, + ) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'batch_size': batch_size, + 'conv_kernel_size': conv_kernel_size, + 'conv_dim': conv_dim, + 'tp_size': tp_size, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='conv1d_update', + kernel_source='default', + perf_filename=perf_filename + ) diff --git a/collector/trtllm/collect_gated_delta_rule.py b/collector/trtllm/collect_gated_delta_rule.py new file mode 100644 index 00000000..2315e909 --- /dev/null +++ b/collector/trtllm/collect_gated_delta_rule.py @@ -0,0 +1,223 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import os +from cuda import cuda +import torch +import tensorrt_llm +from tensorrt_llm._torch.modules.fla.chunk import chunk_gated_delta_rule +from tensorrt_llm._torch.modules.fla.fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update +from helper import log_perf + +def get_chunk_gated_delta_rule_test_cases(): + """ + Generate test cases for chunk_gated_delta_rule() operations. + + Test parameters: + - num_heads: number of heads + - head_k_dim: dimension of the key heads + - head_v_dim: dimension of the value heads + - num_value_heads: number of value heads + - isl: sequence length + """ + num_heads_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + head_k_dim_list = [1,2,4,8,16,32,64,128] + head_v_dim_list = [1,2,4,8,16,32,64,128] + num_value_heads_list = [1,2,4,8,16,32,64,128] + isl_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + + test_cases = [] + for num_heads in num_heads_list: + for head_k_dim in head_k_dim_list: + for head_v_dim in head_v_dim_list: + for num_value_heads in num_value_heads_list: + # Skip invalid combinations: num_heads must be >= num_value_heads and divisible by it + # This constraint is typical for Grouped-Query Attention (GQA) + if num_heads < num_value_heads or num_heads % num_value_heads != 0: + continue + for isl in isl_list: + test_cases.append([num_heads, head_k_dim, head_v_dim, num_value_heads, isl, 'chunk_gated_delta_rule_perf.txt']) + + return test_cases + + +def run_chunk_gated_delta_rule(num_heads, head_k_dim, head_v_dim, num_value_heads, isl, perf_filename, device='cuda:0'): + """ + Run chunk_gated_delta_rule() performance benchmarking. 
+ + Args: + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + isl: Sequence length + perf_filename: Output file for performance results + device: CUDA device to use + """ + # NOTICE: ignored fused_gdn_gating operation + dtype = torch.bfloat16 + q = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + k = torch.randn((1, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + v = torch.randn((1, isl, num_value_heads, head_v_dim), dtype=dtype).to(torch.device(device)) + gate = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) + beta = torch.randn((1, isl, num_heads), dtype=dtype).to(torch.device(device)) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + chunk_gated_delta_rule(q, k, v, gate, beta) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'num_heads': num_heads, + 'head_k_dim': head_k_dim, + 'head_v_dim': head_v_dim, + 'num_value_heads': num_value_heads, + 'isl': isl, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='chunk_gated_delta_rule', + kernel_source='default', + perf_filename=perf_filename + ) + +def get_gated_delta_rule_update_test_cases(): + """ + Generate test cases for gated_delta_rule_update operations. + + Test parameters: + - batch_size: batch size + - isl: sequence length + - num_heads: number of heads + - head_k_dim: dimension of the key heads + - head_v_dim: dimension of the value heads + - num_value_heads: number of value heads + - max_batch_size: maximum batch size + """ + b_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + s_list = [2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072] + num_heads_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + head_k_dim_list = [1,2,4,8,16,32,64,128] + head_v_dim_list = [1,2,4,8,16,32,64,128] + num_value_heads_list = [1,2,4,8,16,32,64,128] + max_batch_size_list = [1,2,4,8,16,32,64,128,256,512,1024,2048] + + test_cases = [] + for batch_size in b_list: + for isl in s_list: + for num_heads in num_heads_list: + for head_k_dim in head_k_dim_list: + for head_v_dim in head_v_dim_list: + for num_value_heads in num_value_heads_list: + # Skip invalid combinations: num_heads must be >= num_value_heads and divisible by it + # This constraint is typical for Grouped-Query Attention (GQA) + if num_heads < num_value_heads or num_heads % num_value_heads != 0: + continue + for max_batch_size in max_batch_size_list: + # max_batch_size must be >= batch_size + if max_batch_size < batch_size: + continue + test_cases.append([batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size, 'gated_delta_rule_update_perf.txt']) + + return test_cases + + +def run_gated_delta_rule_update(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size, perf_filename, device='cuda:0'): + """ + Run fused_sigmoid_gating_delta_rule_update() performance benchmarking.
+ + Args: + batch_size: Batch size + isl: Sequence length + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + max_batch_size: Maximum batch size + perf_filename: Output file for performance results + device: CUDA device to use + """ + dtype = torch.bfloat16 + A_log = torch.randn((num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + dt_bias = torch.randn((num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + q = torch.randn((batch_size, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + k = torch.randn((batch_size, isl, num_heads, head_k_dim), dtype=dtype).to(torch.device(device)) + v = torch.randn((batch_size, isl, num_value_heads, head_v_dim), dtype=dtype).to(torch.device(device)) + a = torch.randn((batch_size * isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + b = torch.randn((batch_size, isl, num_heads * num_value_heads), dtype=dtype).to(torch.device(device)) + initial_state_source = torch.randn((max_batch_size, num_heads * num_value_heads, head_k_dim, head_v_dim), dtype=dtype).to(torch.device(device)) + # initial_state_indices should be integers, not floats - they index into initial_state_source + initial_state_indices = torch.randint(0, max_batch_size, (batch_size,), dtype=torch.int32, device=device) + softplus_beta = 1.0 + softplus_threshold = 20.0 + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + # TODO: measure optional arguments + fused_sigmoid_gating_delta_rule_update( + A_log =A_log, + dt_bias = dt_bias, + q = q, + k = k, + v = v, + a = a, + b = b, + initial_state_source = initial_state_source, + initial_state_indices = initial_state_indices, + softplus_beta = softplus_beta, + softplus_threshold = softplus_threshold, + ) + + num_warmups = 3 + num_runs = 6 + + # warmup + for _ in range(num_warmups): + g.replay() + + # measure + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(num_runs): + g.replay() + end_event.record() + torch.cuda.synchronize() + latency = start_event.elapsed_time(end_event)/num_runs + + log_perf( + item_list=[{ + 'batch_size': batch_size, + 'isl': isl, + 'num_heads': num_heads, + 'head_k_dim': head_k_dim, + 'head_v_dim': head_v_dim, + 'num_value_heads': num_value_heads, + 'max_batch_size': max_batch_size, + 'latency': latency + }], + framework='TRTLLM', + version=tensorrt_llm.__version__, + device_name=torch.cuda.get_device_name(device), + op_name='gated_delta_rule_update', + kernel_source='default', + perf_filename=perf_filename + ) diff --git a/src/aiconfigurator/sdk/backends/base_backend.py b/src/aiconfigurator/sdk/backends/base_backend.py index 7bc778d3..e0558500 100644 --- a/src/aiconfigurator/sdk/backends/base_backend.py +++ b/src/aiconfigurator/sdk/backends/base_backend.py @@ -46,6 +46,7 @@ def run_static(self, step, default is 32. latency_correction_scale (float): the correction scale to adjust the latency, default is 1.0.
corrected latency = latency * latency_correction_scale """ + def _run_context(batch_size: int, isl: int) -> dict[str, float]: context_latency_dict = defaultdict(float) diff --git a/src/aiconfigurator/sdk/backends/trtllm_backend.py b/src/aiconfigurator/sdk/backends/trtllm_backend.py index d5131b40..9ae3d38c 100644 --- a/src/aiconfigurator/sdk/backends/trtllm_backend.py +++ b/src/aiconfigurator/sdk/backends/trtllm_backend.py @@ -317,7 +317,8 @@ def _get_memory_usage(self, c_dict = {1:11, 2:6.5, 4:5, 8:5} activations = 2*num_tokens*h*c_dict[min(model.config.tp_size, 8)] activations = max(activations, 70*1024*1024) # minimum act - elif get_model_family(model.model_name) == 'MOE': + elif get_model_family(model.model_name) in ['MOE', 'QWEN3NEXT']: + # TODO: Qwen3Next has different activation memory calculation. c_dict = {1:22, 2:13, 4:10, 8:10} activations = 2*num_tokens*h*c_dict[min(model.config.tp_size, 8)] activations = max(activations, 70*1024*1024) # minimum act diff --git a/src/aiconfigurator/sdk/common.py b/src/aiconfigurator/sdk/common.py index 9bf4ea67..531075ce 100644 --- a/src/aiconfigurator/sdk/common.py +++ b/src/aiconfigurator/sdk/common.py @@ -23,6 +23,27 @@ class BlockConfig: ffn_no_op: bool = False num_inst: int = 0 +@dataclass(frozen=True) +class LinearAttentionConfig: + """ + Configuration for a single linear attention block in Qwen3Next. + + Attributes: + used_ratio (float): Used ratio of the linear attention block within all attention blocks + linear_conv_kernel_dim (int): Kernel dimension for the linear convolution + linear_key_head_dim (int): Head dimension for the linear key + linear_num_key_heads (int): Number of key heads for the linear attention + linear_num_value_heads (int): Number of value heads for the linear attention + linear_value_head_dim (int): Head dimension for the linear value + """ + used_ratio: float = 0.75 + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: int = 128 + linear_num_key_heads: int = 16 + linear_num_value_heads: int = 32 + linear_value_head_dim: int = 128 + + """ Supported models model name: model_family,l,n,n_kv,d,hidden_size,inter_size,vocab,context,topk,num_experts,moe_inter_size,extra_params @@ -59,6 +80,7 @@ class BlockConfig: 'QWEN3_8B':['LLAMA', 36,32,8,128,32*128,12288,151936,40960, 0, 0, 0, None], 'QWEN3_235B':['MOE', 94,64,4,128,4096,12288,151936,40960, 8, 128, 1536, None], 'QWEN3_480B':['MOE', 62,96,8,128,6144,8192,151936,262144,8,160,2560, None], + 'QWEN3_NEXT_80B':['QWEN3NEXT', 48,16,2,256,2048,5120,151936,262144,10,512,512, LinearAttentionConfig(0.75, 4, 128, 16, 32, 128)], 'Nemotron_super_v1.1':['NEMOTRONNAS', 80, 64, 0, 128, 8192, 0, 128256, 131072, 0, 0, 0, [ BlockConfig(8, False, 5.25, False, 48), @@ -78,7 +100,7 @@ class BlockConfig: """ Model family for model definition """ -ModelFamily = {'GPT', 'LLAMA', 'MOE', 'DEEPSEEK', 'NEMOTRONNAS'} +ModelFamily = {'GPT', 'LLAMA', 'MOE', 'DEEPSEEK', 'NEMOTRONNAS', 'QWEN3NEXT'} """ All reduce strategy for trtllm custom allreduce diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py index 3aa3d8df..db0fbce5 100755 --- a/src/aiconfigurator/sdk/models.py +++ b/src/aiconfigurator/sdk/models.py @@ -18,7 +18,7 @@ def get_model(model_name: str, model_config: config.ModelConfig, backend_name: s """ assert(model_name in common.SupportedModels), f"unsupport model {model_name}" model_family,l,n,n_kv,d,hidden,inter,vocab,context,topk,num_experts,moe_inter_size, extra_params = common.SupportedModels[model_name] - assert(model_family in common.ModelFamily), f"model 
is not in ModelFamily(GPT, LLAMA, MOE, DEEPSEEK, NEMOTRONNAS)" + assert(model_family in common.ModelFamily), f"model is not in ModelFamily(GPT, LLAMA, MOE, DEEPSEEK, NEMOTRONNAS, QWEN3NEXT)" if model_config.overwrite_num_layers > 0: l = model_config.overwrite_num_layers @@ -53,6 +53,13 @@ def get_model(model_name: str, model_config: config.ModelConfig, backend_name: s model_config) model.context_ops = extra_params model.generation_ops = extra_params + elif model_family == 'QWEN3NEXT': + model = Qwen3NextModel(topk, num_experts, moe_inter_size, \ + model_name, model_family, l, n, n_kv, d, \ + hidden, inter, vocab, context, \ + model_config) + model.context_ops = extra_params + model.generation_ops = extra_params return model @@ -68,7 +75,7 @@ def check_is_moe(model_name: str) -> bool: """ Check if the model is a MoE model. """ - return get_model_family(model_name) == 'MOE' or get_model_family(model_name) == 'DEEPSEEK' + return get_model_family(model_name) == 'MOE' or get_model_family(model_name) == 'DEEPSEEK' or get_model_family(model_name) == 'QWEN3NEXT' def calc_expectation(nextn: int, nextn_accept_rates: list[float]) -> float: """ @@ -732,7 +739,259 @@ def _ffn_mult_to_intermediate_size(self, ffn_mult: float) -> int: if inter_size % 256 == 0: return inter_size return inter_size + 256 - (inter_size % 256) - + + +class Qwen3NextModel(BaseModel): + """ + Qwen3Next model uses this model impl. + Currently Qwen3Next only has a series of 80B A3B models, which is similar to MOEModel but with different attention: + 1/4 of the layers are the same to MOE model, using self attention. + 3/4 of the layers are using linear attention with convolution 1d operation. + Some rules to follow, + Due to implementation, attn layer name needs to be context_attention or generation_attention, exact match is required. Same for logits_gemm. + + Refer to tensorrt_llm/_torch/models/modeling_qwen3_next.py for more details. + """ + def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> None: + super().__init__(*args) + assert self._nextn == 0, 'Qwen3Next only supports mtp=0' + + # make sure the parallel width is same + assert(self.config.tp_size * self.config.attention_dp_size == self.config.moe_tp_size * self.config.moe_ep_size), \ + f"tp_size ({self.config.tp_size}) * attention_dp_size ({self.config.attention_dp_size}) should be equal to moe_tp_size ({self.config.moe_tp_size}) * moe_ep_size ({self.config.moe_ep_size})" + + assert(num_experts >= self.config.moe_ep_size), f"ep size cannot be larger than num_experts {num_experts}" + assert(self.config.tp_size * self.config.attention_dp_size <= 256), f"moe ep size {self.config.moe_ep_size} * moe tp size {self.config.moe_tp_size} should not be larger than 256" + assert(self._num_layers % 4 == 0), f"num_layers {self._num_layers} should be divisible by 4" + + self._topk = topk + self._num_experts = num_experts + self._moe_inter_size = moe_inter_size + + self._power_law_alpha = 1.2 + + @property + def context_ops(self): + """ + Get the context(prefill) processing operations pipeline. + + Returns: + List[ops.Operation]: List of operations for processing context + sequences, including: + - embedding, + - attention blocks, + - FFN blocks, + - P2P communication, + - all reduce communication + - logits computation. + """ + return self._context_ops + + @context_ops.setter + def context_ops(self, linear_attention_config: common.LinearAttentionConfig): + """ + Set the context(prefill) processing operations pipeline based on linear attention configurations. 
+ + Constructs a pipeline of operations for processing input context by creating operations + for each configured transformer block. The pipeline includes embedding lookup, + transformer blocks (with optional attention and FFN components), pipeline parallel + communication, and final logits computation. + + Args: + linear_attention_config (common.LinearAttentionConfig or list): Linear attention configuration + or empty list for initialization + """ + self._context_ops = [] + if not isinstance(linear_attention_config, common.LinearAttentionConfig): + return + + num_v_heads = linear_attention_config.linear_num_value_heads + num_k_heads = linear_attention_config.linear_num_key_heads + head_k_dim = linear_attention_config.linear_key_head_dim + head_v_dim = linear_attention_config.linear_value_head_dim + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + conv_kernel_size = linear_attention_config.linear_conv_kernel_dim + conv_dim = key_dim * 2 + value_dim + moe_quant_mode = self.config.moe_quant_mode + h = self._hidden_size + tp_size = self.config.tp_size + moe_tp_size = self.config.moe_tp_size + moe_ep_size = self.config.moe_ep_size + attention_dp_size = self.config.attention_dp_size + pp_size = self.config.pp_size + num_kv_heads_per_GPU = self._num_kv_heads_per_GPU + gemm_quant_mode = self.config.gemm_quant_mode + kvcache_quant_mode = self.config.kvcache_quant_mode + fmha_quant_mode = self.config.fmha_quant_mode + workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}" + + # 1 embedding for all layers + # 1 norm before attention per layer + self._context_ops.extend([ops.Embedding(f'context_embedding', 1, self._vocab_size, h, 0.3), + ops.ElementWise(f'context_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) + + # self attention + self._context_ops.extend([ + ops.GEMM(f'context_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), + ops.ContextAttention(f'context_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode), + ops.GEMM(f'context_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), + ]) + + # linear attention (Qwen3NextGatedDeltaNet) + self._context_ops.extend([ + # Input projections for qkvz and ba + ops.GEMM(f'context_qkvz_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2)//tp_size, h, gemm_quant_mode), + ops.GEMM(f'context_ba_gemm', self._num_layers * linear_attention_config.used_ratio, num_v_heads*2//tp_size, h, gemm_quant_mode), + # Conv1D and gated delta rule operations - weights handled internally + # Conv1DFn(name, scale_factor, conv_kernel_size, conv_dim, tp_size) - batch_size and isl from kwargs + ops.Conv1DFn(f'context_conv1d_fn', self._num_layers * linear_attention_config.used_ratio, conv_kernel_size, conv_dim, tp_size), + # ChunkGatedDeltaRule(name, scale_factor, num_heads, head_k_dim, head_v_dim, num_value_heads) - isl from kwargs + ops.ChunkGatedDeltaRule(f'context_chunk_gated_delta_rule', self._num_layers * linear_attention_config.used_ratio, num_k_heads, head_k_dim, head_v_dim, num_v_heads), + # Output projection + ops.GEMM(f'context_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), + ]) + + # 1 norm before MOE per layer 
+ self._context_ops.extend([ + ops.ElementWise(f'context_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) + + #router, only take it into account when num_experts >= 128 + if self._num_experts >= 128: + self._context_ops.extend([ + ops.GEMM(f'context_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) + ]) + + # dispatch tokens to experts, moe calc and get tokens back + # Qwen3Next has one more shared expert. + self._context_ops.extend([ + ops.MoEDispatch(f'context_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), + ops.MoE(f'context_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), + ops.MoEDispatch(f'context_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False)]) + + self._context_ops.extend([ops.GEMM(f'context_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + + # # # when tp_size=0, the comm part will be 0 + # self._context_ops.append(ops.AllReduce('context_ar_1', self._num_layers, h, tp_size)) + # self._context_ops.append(ops.AllReduce('context_ar_2', self._num_layers, h, tp_size)) + + # pp + pp_scale_factor = pp_size-1 + self._context_ops.append(ops.P2P('context_p2p', pp_scale_factor, h, pp_size)) + + @property + def generation_ops(self): + """ + Get the generation (decoding) operations pipeline. + + Returns: + List[ops.Operation]: List of operations for the decoding phase + including: + - embedding, + - attention blocks, + - FFN blocks, + - P2P communication, + - all reduce communication + - logits computation. + """ + return self._generation_ops + + @generation_ops.setter + def generation_ops(self, linear_attention_config: common.LinearAttentionConfig): + """ + Set the generation (decoding) operations pipeline based on linear attention configurations. + + Constructs a pipeline of operations for generating output tokens by creating operations + for each configured transformer block. The pipeline includes embedding lookup, + transformer blocks (with optional attention and FFN components), pipeline parallel + communication, and final logits computation. 
+ + Args: + linear_attention_config (common.LinearAttentionConfig): Linear attention configuration + or empty list for initialization + """ + self._generation_ops = [] + if not isinstance(linear_attention_config, common.LinearAttentionConfig): + return + + num_v_heads = linear_attention_config.linear_num_value_heads + num_k_heads = linear_attention_config.linear_num_key_heads + head_k_dim = linear_attention_config.linear_key_head_dim + head_v_dim = linear_attention_config.linear_value_head_dim + key_dim = head_k_dim * num_k_heads + value_dim = head_v_dim * num_v_heads + conv_kernel_size = linear_attention_config.linear_conv_kernel_dim + conv_dim = key_dim * 2 + value_dim + moe_quant_mode = self.config.moe_quant_mode + h = self._hidden_size + tp_size = self.config.tp_size + moe_tp_size = self.config.moe_tp_size + moe_ep_size = self.config.moe_ep_size + attention_dp_size = self.config.attention_dp_size + pp_size = self.config.pp_size + num_kv_heads_per_GPU = self._num_kv_heads_per_GPU + gemm_quant_mode = self.config.gemm_quant_mode + kvcache_quant_mode = self.config.kvcache_quant_mode + fmha_quant_mode = self.config.fmha_quant_mode + workload_distribution = self.config.workload_distribution + f"_{self._power_law_alpha}" + + # 1 embedding for all layers + # 1 norm before attention per layer + self._generation_ops.extend([ops.Embedding(f'generation_embedding', 1, self._vocab_size, h, 0.3), + ops.ElementWise(f'generation_add_norm_1', self._num_layers, 2*h, 2*h, 0.8)]) + + # self attention + self._generation_ops.extend([ + ops.GEMM(f'generation_qkv_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads*self._head_size//tp_size+self._head_size*num_kv_heads_per_GPU*2, h, gemm_quant_mode), + ops.GenerationAttention(f'generation_attention', self._num_layers * (1 - linear_attention_config.used_ratio), self._num_heads//tp_size, num_kv_heads_per_GPU, kvcache_quant_mode, fmha_quant_mode), + ops.GEMM(f'generation_proj_gemm', self._num_layers * (1 - linear_attention_config.used_ratio), h, self._num_heads*self._head_size//tp_size, gemm_quant_mode), + ]) + + # linear attention (Qwen3NextGatedDeltaNet) + self._generation_ops.extend([ + # Input projections for qkvz and ba + ops.GEMM(f'generation_qkvz_gemm', self._num_layers * linear_attention_config.used_ratio, (key_dim*2 + value_dim*2)//tp_size, h, gemm_quant_mode), + ops.GEMM(f'generation_ba_gemm', self._num_layers * linear_attention_config.used_ratio, num_v_heads*2//tp_size, h, gemm_quant_mode), + # Conv1D and gated delta rule operations - weights handled internally + # TODO: for mixed steps, add ops.Conv1DFn(...) 
+ # Conv1DUpdate(name, scale_factor, conv_kernel_size, conv_dim, tp_size) - batch_size and isl from kwargs + ops.Conv1DUpdate(f'generation_conv1d_update', self._num_layers * linear_attention_config.used_ratio, conv_kernel_size, conv_dim, tp_size), + # GatedDeltaRuleUpdate(name, scale_factor, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size) - batch_size and isl from kwargs + # max_batch_size is dynamic, need to determine appropriate value + ops.GatedDeltaRuleUpdate(f'generation_gated_delta_rule_update', self._num_layers * linear_attention_config.used_ratio, num_k_heads, head_k_dim, head_v_dim, num_v_heads, 1024), + # Output projection + ops.GEMM(f'generation_proj_gemm', self._num_layers * linear_attention_config.used_ratio, h, value_dim//tp_size, gemm_quant_mode), + ]) + + # 1 norm before MOE per layer + self._generation_ops.extend([ + ops.ElementWise(f'generation_add_norm_2', self._num_layers // linear_attention_config.used_ratio, 2*h, 2*h, 0.8)]) + + #router, only take it into account when num_experts >= 128 + if self._num_experts >= 128: + self._generation_ops.extend([ + ops.GEMM(f'generation_router_gemm', self._num_layers, self._num_experts, h, common.GEMMQuantMode.float16) + ]) + + # dispatch tokens to experts, moe calc and get tokens back + # Qwen3Next has one more shared expert. + self._generation_ops.extend([ + ops.MoEDispatch(f'generation_moe_pre_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, True), + ops.MoE(f'generation_moe', self._num_layers, h, self._moe_inter_size, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, moe_quant_mode, workload_distribution, attention_dp_size), + ops.MoEDispatch(f'generation_moe_post_dispatch', self._num_layers, h, self._topk, self._num_experts + 1, moe_tp_size, moe_ep_size, attention_dp_size, False) + ]) + # logits gemm + self._generation_ops.extend([ops.GEMM(f'generation_logits_gemm', 1, self._vocab_size//tp_size, h, common.GEMMQuantMode.float16)]) + + # # # when tp_size=0, the comm part will be 0 + # self._generation_ops.append(ops.AllReduce('generation_ar_1', self._num_layers, h, tp_size)) + # self._generation_ops.append(ops.AllReduce('generation_ar_2', self._num_layers, h, tp_size)) + + # pp + pp_scale_factor = pp_size-1 + self._generation_ops.append(ops.P2P('generation_p2p', pp_scale_factor, h, pp_size)) + + if __name__ == '__main__': # TODO, move to unit tests model = get_model('DEEPSEEK_V3', config.ModelConfig( diff --git a/src/aiconfigurator/sdk/operations.py b/src/aiconfigurator/sdk/operations.py index e69db25d..9072eb78 100755 --- a/src/aiconfigurator/sdk/operations.py +++ b/src/aiconfigurator/sdk/operations.py @@ -90,7 +90,8 @@ def __init__(self, name: str, scale_factor: float, n: int, k: int, quant_mode: c self._n = n self._k = k self._quant_mode = quant_mode - self._weights = self._n*self._k*quant_mode.value.memory + self._weights = self._n*self._k*quant_mode.value.memory + def query(self, database:PerfDatabase, **kwargs): x = kwargs.get('x') overwrite_quant_mode = kwargs.get('quant_mode', None) @@ -99,7 +100,7 @@ def query(self, database:PerfDatabase, **kwargs): return database.query_gemm(x, self._n, self._k, quant_mode)*self._scale_factor def get_weights(self, **kwargs): - return self._weights * self._scale_factor + return self._weights * self._scale_factor class MoE(Operation): """ @@ -577,4 +578,82 @@ def query(self, database:PerfDatabase, **kwargs): return database.query_context_mla_sglang(batch_size, isl, self._tp_size, 
self._kvcache_quant_mode, self._fmha_quant_mode, self._attn_backend) * self._scale_factor def get_weights(self, **kwargs): - return self._weights * self._scale_factor \ No newline at end of file + return self._weights * self._scale_factor + +class Conv1DFn(Operation): + """ + Conv1DFn operation. + """ + def __init__(self, name: str, scale_factor: float, conv_kernel_size: int, conv_dim: int, tp_size: int) -> None: + super().__init__(name, scale_factor) + self._conv_kernel_size = conv_kernel_size + self._conv_dim = conv_dim + self._tp_size = tp_size + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_conv1d_fn(batch_size, isl, self._conv_kernel_size, self._conv_dim, self._tp_size)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + +class Conv1DUpdate(Operation): + """ + Conv1DUpdate operation. + """ + def __init__(self, name: str, scale_factor: float, conv_kernel_size: int, conv_dim: int, tp_size: int) -> None: + super().__init__(name, scale_factor) + self._conv_kernel_size = conv_kernel_size + self._conv_dim = conv_dim + self._tp_size = tp_size + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_conv1d_update(batch_size, isl, self._conv_kernel_size, self._conv_dim, self._tp_size)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + +class ChunkGatedDeltaRule(Operation): + """ + Chunk gated delta rule operation. + """ + def __init__(self, name: str, scale_factor: float, num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int) -> None: + super().__init__(name, scale_factor) + self._num_heads = num_heads + self._head_k_dim = head_k_dim + self._head_v_dim = head_v_dim + self._num_value_heads = num_value_heads + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + isl = kwargs.get('s') + return database.query_chunk_gated_delta_rule(self._num_heads, self._head_k_dim, self._head_v_dim, self._num_value_heads, isl)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor + +class GatedDeltaRuleUpdate(Operation): + """ + Gated delta rule update operation. 
+ """ + def __init__(self, name: str, scale_factor: float, num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int, max_batch_size: int) -> None: + super().__init__(name, scale_factor) + self._num_heads = num_heads + self._head_k_dim = head_k_dim + self._head_v_dim = head_v_dim + self._num_value_heads = num_value_heads + self._max_batch_size = max_batch_size + self._weights = 0.0 + + def query(self, database:PerfDatabase, **kwargs): + batch_size = kwargs.get('batch_size') + isl = kwargs.get('s') + return database.query_gated_delta_rule_update(batch_size, isl, self._num_heads, self._head_k_dim, self._head_v_dim, self._num_value_heads, self._max_batch_size)*self._scale_factor + + def get_weights(self, **kwargs): + return self._weights * self._scale_factor diff --git a/src/aiconfigurator/sdk/perf_database.py b/src/aiconfigurator/sdk/perf_database.py index 7e1853d1..9068a821 100755 --- a/src/aiconfigurator/sdk/perf_database.py +++ b/src/aiconfigurator/sdk/perf_database.py @@ -1959,7 +1959,7 @@ def get_sol(num_tokens: int, hidden_size: int, intermediate_size: int, quant_mod num_left, num_right = self._nearest_1d_point_helper(num_tokens, list(mlp_dict.keys()), inner_only=False) lat = self._interp_1d([num_left, num_right], [mlp_dict[num_left], mlp_dict[num_right]], num_tokens) return lat - + def query_deepep_ll(self, node_num: int, num_tokens: int, @@ -2014,6 +2014,224 @@ def get_sol(num_tokens: int, num_experts: int, topk: int, hidden_size: int) -> T data = self._deepep_normal_data[node_num][hidden_size][topk][num_experts] lat = self._interp_2d_linear(sms, num_tokens, data) return lat / 1000.0 + + def query_conv1d_fn(self, + batch_size: int, + isl: int, + conv_kernel_size: int, + conv_dim: int, + tp_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Conv1D Fn operation data. 
+ + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Tensor parallel size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds + """ + def get_sol(batch_size: int, isl: int, conv_kernel_size: int, conv_dim: int, tp_size: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Conv1D Fn + """ + # Conv1D operations: batch_size * isl * (conv_dim // tp_size) * conv_kernel_size + ops = batch_size * isl * (conv_dim // tp_size) * conv_kernel_size * 2 # 2 for FMA + mem_bytes = 2 * ( # Assuming fp16/bf16 + batch_size * isl * (conv_dim // tp_size) + # Input + (conv_dim // tp_size) * conv_kernel_size + # Weights + batch_size * isl * (conv_dim // tp_size) # Output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + + def query_conv1d_update(self, + batch_size: int, + isl: int, + conv_kernel_size: int, + conv_dim: int, + tp_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Conv1D Update operation data. + + Args: + batch_size: Batch size + isl: Sequence length + conv_kernel_size: Size of the convolution kernel + conv_dim: Dimension of the convolution + tp_size: Tensor parallel size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds + """ + def get_sol(batch_size: int, isl: int, conv_kernel_size: int, conv_dim: int, tp_size: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Conv1D Update + """ + # Conv1D update is typically lighter than full conv1d_fn + ops = batch_size * isl * (conv_dim // tp_size) * conv_kernel_size * 2 # 2 for FMA + mem_bytes = 2 * ( # Assuming fp16/bf16 + batch_size * isl * (conv_dim // tp_size) + # Input + (conv_dim // tp_size) * conv_kernel_size + # Weights + batch_size * isl * (conv_dim // tp_size) # Output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, conv_kernel_size, conv_dim, tp_size)[0] + + def query_chunk_gated_delta_rule(self, + num_heads: int, + head_k_dim: int, + head_v_dim: int, + num_value_heads: int, + isl: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Chunk Gated Delta Rule operation data. 
+ + Args: + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + isl: Sequence length + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds + """ + def get_sol(num_heads: int, head_k_dim: int, head_v_dim: int, num_value_heads: int, isl: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Chunk Gated Delta Rule + """ + # Gated delta rule involves attention-like operations + # Operations: q*k^T, gating, and weighted sum with values + ops = ( + num_heads * isl * isl * head_k_dim * 2 + # q*k^T + num_heads * isl * isl * 2 + # gating operations + num_value_heads * isl * isl * head_v_dim * 2 # weighted sum with values + ) + mem_bytes = 2 * ( # Assuming fp16/bf16 + num_heads * isl * head_k_dim + # Q + num_heads * isl * head_k_dim + # K + num_value_heads * isl * head_v_dim + # V + num_heads * isl + # gate + num_heads * isl + # beta + num_value_heads * isl * head_v_dim # output + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(num_heads, head_k_dim, head_v_dim, num_value_heads, isl)[0] + + def query_gated_delta_rule_update(self, + batch_size: int, + isl: int, + num_heads: int, + head_k_dim: int, + head_v_dim: int, + num_value_heads: int, + max_batch_size: int, + sol_mode: Optional[common.SOLMode] = None) -> float: + """ + Query the Gated Delta Rule Update operation data. 
+ Args: + batch_size: Batch size + isl: Sequence length + num_heads: Number of heads + head_k_dim: Dimension of the key heads + head_v_dim: Dimension of the value heads + num_value_heads: Number of value heads + max_batch_size: Maximum batch size + sol_mode: SOL mode for theoretical performance calculation + + Returns: + Latency in milliseconds + """ + def get_sol(batch_size: int, isl: int, num_heads: int, head_k_dim: int, + head_v_dim: int, num_value_heads: int, max_batch_size: int) -> Tuple[float, float, float]: + """ + Get the sol time, sol math and sol mem for Gated Delta Rule Update + """ + # Fused sigmoid gating delta rule update involves state updates + ops = ( + batch_size * isl * num_heads * head_k_dim * 2 + # q processing + batch_size * isl * num_heads * head_k_dim * 2 + # k processing + batch_size * isl * num_value_heads * head_v_dim * 2 + # v processing + batch_size * isl * num_heads * num_value_heads * 2 + # gating operations + max_batch_size * num_heads * num_value_heads * head_k_dim * head_v_dim * 2 # state operations + ) + mem_bytes = 2 * ( # Assuming fp16/bf16 + num_heads * num_value_heads + # A_log + num_heads * num_value_heads + # dt_bias + batch_size * isl * num_heads * head_k_dim + # q + batch_size * isl * num_heads * head_k_dim + # k + batch_size * isl * num_value_heads * head_v_dim + # v + batch_size * isl * num_heads * num_value_heads + # a + batch_size * isl * num_heads * num_value_heads + # b + max_batch_size * num_heads * num_value_heads * head_k_dim * head_v_dim + # initial_state_source + batch_size # initial_state_indices + ) + sol_math = ops / self.system_spec['gpu']['float16_tc_flops'] * 1000 + sol_mem = mem_bytes / self.system_spec['gpu']['mem_bw'] * 1000 + sol_time = max(sol_math, sol_mem) + return sol_time, sol_math, sol_mem + + if sol_mode is None: + sol_mode = self._default_sol_mode + if sol_mode == common.SOLMode.SOL: + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size)[0] + elif sol_mode == common.SOLMode.SOL_FULL: + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size) + else: + # TODO: Add actual data interpolation when measurement data is available + # For now, return SOL estimation + return get_sol(batch_size, isl, num_heads, head_k_dim, head_v_dim, num_value_heads, max_batch_size)[0] + + if __name__ == '__main__': database_dict = get_all_databases() \ No newline at end of file
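
Reviewer note: the new collectors in collector/trtllm/collect_conv1d.py and collect_gated_delta_rule.py all repeat the same capture-and-replay timing pattern (capture the op in a CUDA graph, warm up with a few replays, then average the elapsed time over timed replays). A minimal standalone sketch of that pattern is below; the run_op callable and the warmup/run counts are placeholders for illustration, not part of the patch.

import torch

def time_op_with_cuda_graph(run_op, num_warmups: int = 3, num_runs: int = 6) -> float:
    """Return the average latency in ms of run_op() captured in a CUDA graph."""
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        run_op()
    # warmup replays
    for _ in range(num_warmups):
        g.replay()
    # timed replays
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(num_runs):
        g.replay()
    end_event.record()
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / num_runs

Used the way the collectors do, the call would look roughly like time_op_with_cuda_graph(lambda: causal_conv1d_fn(mixed_qkv, conv1d_weights)); factoring it out like this could also reduce the duplication across the four new run_* functions.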
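The query_* additions in perf_database.py fall back to a roofline ("SOL") estimate until measured data is wired in: a math-bound time from FLOPs, a memory-bound time from bytes moved, and the max of the two. A small self-contained illustration of that calculation for conv1d_fn is below; the peak-FLOPS and bandwidth numbers are made-up placeholders, not values taken from the repo's system spec.

def conv1d_fn_sol_ms(batch_size, isl, conv_kernel_size, conv_dim, tp_size,
                     peak_fp16_flops=9.89e14, mem_bw=3.35e12):
    # FLOPs: one FMA (2 flops) per (token, channel, tap); channels are split across TP ranks
    channels = conv_dim // tp_size
    ops = batch_size * isl * channels * conv_kernel_size * 2
    # Bytes: read input and weights, write output, 2 bytes per bf16 element
    mem_bytes = 2 * (batch_size * isl * channels
                     + channels * conv_kernel_size
                     + batch_size * isl * channels)
    sol_math = ops / peak_fp16_flops * 1000  # ms
    sol_mem = mem_bytes / mem_bw * 1000      # ms
    return max(sol_math, sol_mem)

# Example: batch 8, isl 4096, kernel 4, conv_dim 8192, tp 2 -> memory-bound, ~0.16 ms
print(round(conv1d_fn_sol_ms(8, 4096, 4, 8192, 2), 4))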
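In models.py, Qwen3NextModel scales each operation by how many layers use it: used_ratio of the layers take the linear-attention (gated delta net) path and the remainder take full attention, while every layer keeps its MoE FFN plus one shared expert. A quick sanity check of those scale factors for the QWEN3_NEXT_80B entry (48 layers, used_ratio 0.75) is shown below; it is just arithmetic on the values in common.py, not code from the patch.

num_layers = 48
used_ratio = 0.75  # LinearAttentionConfig.used_ratio for QWEN3_NEXT_80B

linear_attention_layers = num_layers * used_ratio      # 36 layers: conv1d + gated delta rule
full_attention_layers = num_layers * (1 - used_ratio)  # 12 layers: regular context/generation attention
moe_layers = num_layers                                # every layer has the 512-expert MoE FFN (+1 shared)

assert linear_attention_layers + full_attention_layers == num_layers
print(linear_attention_layers, full_attention_layers, moe_layers)  # 36.0 12.0 48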