|
4 | 4 | import fcntl |
5 | 5 | import json |
6 | 6 | import logging |
| 7 | +import math |
7 | 8 | import multiprocessing as mp |
8 | 9 | import os |
9 | 10 | import signal |
@@ -264,3 +265,171 @@ def log_perf( |
264 | 265 | f.write(header_prefix + "\n") |
265 | 266 |
|
266 | 267 | f.write(content_prefix + "\n") |
| 268 | + |
| 269 | + |
| 270 | +# Helper functions for MoE |
def balanced_logits(num_tokens, num_experts, topk):
    """Build router logits that spread tokens evenly over the experts.

    Each token picks ``topk`` experts spaced ``ceil(num_experts / topk)``
    apart (offset by the token index), producing a balanced load. The
    one-hot expert map is then softmaxed in bfloat16 to yield logits.

    Returns:
        Tensor of shape ``(num_tokens, num_experts)`` with softmax
        router logits (bfloat16).
    """
    import torch
    import torch.nn.functional as F

    stride = math.ceil(num_experts / topk)
    selected = -torch.ones([num_tokens, topk])

    for tok in range(num_tokens):
        # When there are fewer tokens than the stride, spread the token
        # offsets fractionally so experts are still covered evenly.
        base = tok if num_tokens >= stride else tok * stride / num_tokens
        for slot in range(topk):
            selected[tok][slot] = (base + slot * stride) % num_experts

    one_hot_counts = F.one_hot(selected.long(), num_classes=num_experts).sum(1)
    return F.softmax(one_hot_counts.bfloat16(), dim=1)
| 289 | + |
| 290 | + |
def sample_power_law(size, alpha, xmin, xmax):
    """Draw ``size`` samples from a power law truncated to [xmin, xmax].

    Uses inverse-CDF (inverse transform) sampling for p(x) ∝ x^(-alpha)
    on the interval [xmin, xmax].

    Args:
        size: number of samples (shape passed to ``torch.rand``).
        alpha: power-law exponent.
        xmin: lower truncation bound (> 0).
        xmax: upper truncation bound (> xmin).

    Returns:
        Float tensor of shape ``(size,)`` with values in [xmin, xmax).
    """
    import torch

    u = torch.rand(size)
    if alpha == 1:
        # Bug fix: the general formula divides by (1 - alpha) and raised
        # ZeroDivisionError at alpha == 1; the correct limit is the
        # log-uniform distribution.
        return xmin * (xmax / xmin) ** u
    one_minus = 1 - alpha
    inv_cdf = ((xmax**one_minus - xmin**one_minus) * u + xmin**one_minus) ** (1 / one_minus)
    return inv_cdf
| 297 | + |
| 298 | + |
def power_law_logits_v3(num_tokens, num_experts, topk, ep, alpha):
    """Generate router logits whose per-expert token load follows a power law.

    Samples a tokens-per-expert count vector from a truncated power law,
    rescales and rounds it to sum exactly to ``num_tokens * topk``, swaps
    the busiest expert-parallel (EP) group of ``num_experts // ep`` experts
    to the front, assigns tokens to experts from those counts, and softmaxes
    the resulting one-hot expert map in bfloat16.

    Args:
        num_tokens: number of tokens to route.
        num_experts: total expert count; assumed divisible by ``ep``.
        topk: experts selected per token.
        ep: expert-parallel world size (number of EP groups).
        alpha: power-law exponent passed to ``sample_power_law``.

    Returns:
        Tensor of shape ``(num_tokens, num_experts)`` with softmax router
        logits (bfloat16).
    """
    import torch
    import torch.nn.functional as F

    # Sampling bounds differ by load regime; presumably tuned empirically
    # for benchmarking — TODO confirm with the original authors.
    if num_tokens * topk > num_experts:
        num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
    else:
        num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)

    # Total number of (token, expert) assignments that must be handed out.
    target_sum = num_tokens * topk

    # Normalize the raw samples, scale to the target total, round to ints.
    original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()

    target_distribution = original_distribution * target_sum

    num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

    # Rounding can leave the total off by a few; nudge individual experts
    # until the counts sum exactly to target_sum.
    current_sum = num_tokens_per_expert.sum().item()
    delta = target_sum - current_sum
    if delta != 0:
        sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

        if delta > 0:
            # Surplus: add one token to the largest experts first.
            for i in range(delta):
                expert_idx = sorted_indices[i % len(sorted_indices)]
                num_tokens_per_expert[expert_idx] += 1
        else:
            # Deficit: subtract from the smallest experts first; fall back
            # to the current maximum when an expert would go negative.
            for i in range(-delta):
                expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
                if num_tokens_per_expert[expert_idx] > 0:
                    num_tokens_per_expert[expert_idx] -= 1
                else:
                    num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

    if len(num_tokens_per_expert) > 1:
        sorted_tokens = torch.sort(num_tokens_per_expert, descending=True)[0]
        # NOTE(review): this assert is vacuous — sorted_tokens was just
        # sorted descending, so first >= last always holds.
        assert sorted_tokens[0] >= sorted_tokens[-1], "Power law distribution pattern disrupted"

    # A non-overlapping all-ones Conv1d (kernel == stride == group width)
    # sums each EP group's token count; argmax finds the busiest group.
    with torch.no_grad():
        conv1d = torch.nn.Conv1d(
            in_channels=1,
            out_channels=1,
            kernel_size=num_experts // ep,
            stride=num_experts // ep,
            padding=0,
            bias=False,
        )
        conv1d_weights = torch.tensor([1 for _ in range(num_experts // ep)])
        # copy_ broadcasts/casts the int vector into the (1, 1, k) float weight.
        conv1d.weight.copy_(conv1d_weights)

        res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
        max_ep_idx = torch.argmax(res).item()

    # Swap the busiest EP group into position 0 so rank 0 carries the peak
    # load (clones taken before assignment make the tuple-swap safe).
    if max_ep_idx != 0:
        ep_group_size = num_experts // ep
        num_tokens_per_expert_reshaped = num_tokens_per_expert.view(ep, ep_group_size)
        num_tokens_per_expert_reshaped[0], num_tokens_per_expert_reshaped[max_ep_idx] = (
            num_tokens_per_expert_reshaped[max_ep_idx].clone(),
            num_tokens_per_expert_reshaped[0].clone(),
        )
        num_tokens_per_expert = num_tokens_per_expert_reshaped.view(-1)

    # Optional debug dump of the final per-expert counts.
    aic_debug = int(os.getenv("AIC_DEBUG", "0"))
    if aic_debug == 1:
        print("num_tokens_per_expert", num_tokens_per_expert, num_tokens_per_expert.sum().item())

    # Hand out expert ids, busiest expert first, then fold the flat list into
    # a (topk, num_tokens) grid and transpose to per-token selections.
    # NOTE(review): a token can be assigned the same expert in more than one
    # slot when an expert's count exceeds num_tokens — presumably acceptable
    # for load benchmarking; verify against callers.
    _, num_tokens_per_expert_sorted_index = torch.sort(num_tokens_per_expert, descending=True)
    expert_assignments = []
    num_tokens_per_expert_sorted_index_lists = num_tokens_per_expert_sorted_index.tolist()
    for expert_id in num_tokens_per_expert_sorted_index_lists:
        expert_assignments.extend([expert_id] * num_tokens_per_expert[expert_id])

    expert_assignments = torch.tensor(expert_assignments, dtype=torch.long)
    h_selected_experts = expert_assignments.reshape(topk, num_tokens).T

    expert_map = F.one_hot(h_selected_experts.long(), num_classes=num_experts).sum(1)
    router_logits = F.softmax(expert_map.bfloat16(), dim=1)
    return router_logits
| 377 | + |
| 378 | + |
# NOTE: power_law_logits_v4 was copied from power_law_logits_v3 and
# modified to restrict max tokens per expert to be less than num_tokens
def power_law_logits_v4(num_tokens, num_experts, topk, ep, alpha):
    """Sample a power-law tokens-per-expert vector for the busiest EP rank.

    Unlike ``power_law_logits_v3`` (which returns router logits), this
    returns the integer per-expert token counts of the most heavily loaded
    expert-parallel (EP) group, resampling until no expert in that group
    receives more than ``num_tokens`` tokens.

    Args:
        num_tokens: number of tokens to route.
        num_experts: total expert count; assumed divisible by ``ep``.
        topk: experts selected per token.
        ep: expert-parallel world size (number of EP groups).
        alpha: power-law exponent passed to ``sample_power_law``.

    Returns:
        1-D int64 tensor of length ``num_experts // ep`` holding the token
        count of each expert in the busiest EP group, each <= num_tokens.

    Raises:
        RuntimeError: if no acceptable sample is found within the retry cap
            (the previous unbounded ``while True`` could hang forever).
    """
    import torch

    max_attempts = 10000  # generous cap; each draw is cheap
    for _ in range(max_attempts):
        # Sampling bounds differ by load regime; presumably tuned
        # empirically — TODO confirm with the original authors.
        if num_tokens * topk > num_experts:
            num_tokens_per_expert = sample_power_law(num_experts, alpha, 1, num_tokens * 0.8)
        else:
            num_tokens_per_expert = sample_power_law(num_experts, alpha, 0.01, 2)

        # Total number of (token, expert) assignments to distribute.
        target_sum = num_tokens * topk

        # Normalize, scale to the target total, round to integer counts.
        original_distribution = num_tokens_per_expert / num_tokens_per_expert.sum()
        target_distribution = original_distribution * target_sum
        num_tokens_per_expert = torch.round(target_distribution).to(torch.int64)

        # Rounding can leave the total off by a few; nudge individual
        # experts until the counts sum exactly to target_sum.
        delta = target_sum - num_tokens_per_expert.sum().item()
        if delta != 0:
            sorted_indices = torch.argsort(num_tokens_per_expert, descending=True)

            if delta > 0:
                # Surplus: add one token to the largest experts first.
                for i in range(delta):
                    expert_idx = sorted_indices[i % len(sorted_indices)]
                    num_tokens_per_expert[expert_idx] += 1
            else:
                # Deficit: subtract from the smallest experts first; fall
                # back to the current maximum when one would go negative.
                for i in range(-delta):
                    expert_idx = sorted_indices[-(i % len(sorted_indices)) - 1]
                    if num_tokens_per_expert[expert_idx] > 0:
                        num_tokens_per_expert[expert_idx] -= 1
                    else:
                        num_tokens_per_expert[torch.argmax(num_tokens_per_expert)] -= 1

        # A non-overlapping all-ones Conv1d (kernel == stride == group
        # width) sums each EP group's token count; argmax finds the
        # busiest group.
        with torch.no_grad():
            conv1d = torch.nn.Conv1d(
                in_channels=1,
                out_channels=1,
                kernel_size=num_experts // ep,
                stride=num_experts // ep,
                padding=0,
                bias=False,
            )
            # copy_ broadcasts into the (1, 1, k) float weight.
            conv1d.weight.copy_(torch.ones(num_experts // ep))

            res = conv1d(num_tokens_per_expert.unsqueeze(0).unsqueeze(0).float())
            max_ep_idx = torch.argmax(res).item()

        num_tokens_per_expert_rank0 = num_tokens_per_expert.view(ep, num_experts // ep)[max_ep_idx].view(-1)
        # Accept only if the busiest expert fits within num_tokens.
        if num_tokens_per_expert_rank0.max().item() <= num_tokens:
            return num_tokens_per_expert_rank0

    raise RuntimeError(
        f"power_law_logits_v4: no sample with per-expert load <= {num_tokens} "
        f"tokens found after {max_attempts} attempts"
    )
0 commit comments