|
15 | 15 | from .. import exc |
16 | 16 | from .._compat import get_tensor_descriptor_fn_name |
17 | 17 | from .ast_extension import expr_from_string |
| 18 | +from .ast_extension import statement_from_string |
18 | 19 | from .compile_environment import CompileEnvironment |
19 | 20 | from .device_function import DeviceFunction |
20 | 21 | from .host_function import HostFunction |
@@ -385,21 +386,118 @@ def codegen_store( |
385 | 386 | indexing = BlockedSubscriptIndexing.create(state, fake_tensor, subscript) |
386 | 387 |
|
387 | 388 | # Apply permutation to the value being stored if needed |
388 | | - desc_arg = indexing.tensor_descriptor_arg(state) |
| 389 | + # desc_arg = indexing.tensor_descriptor_arg(state, subtile=True) |
389 | 390 | store_value = indexing.reshape_store(state, value) |
390 | 391 |
|
391 | | - if desc_arg.permutation is not None: |
392 | | - # Apply permutation to the value |
393 | | - store_value = expr_from_string( |
394 | | - f"tl.permute({{store_val}}, {desc_arg.permutation!r})", |
395 | | - store_val=store_value, |
| 392 | + # if desc_arg.permutation is not None: |
| 393 | + # # Apply permutation to the value |
| 394 | + # store_value = expr_from_string( |
| 395 | + # f"tl.permute({{store_val}}, {desc_arg.permutation!r})", |
| 396 | + # store_val=store_value, |
| 397 | + # ) |
| 398 | + |
| 399 | + if ( |
| 400 | + subtile_store := self._codegen_epilogue_subtile_store( |
| 401 | + state, fake_tensor, indexing, store_value |
396 | 402 | ) |
397 | | - |
| 403 | + ) is not None: |
| 404 | + return subtile_store |
| 405 | + |
398 | 406 | return expr_from_string( |
399 | 407 | f"{indexing.tensor_descriptor(state)}.store({indexing.offsets_str_permuted(state)}, {{value}})", |
400 | 408 | value=store_value, |
401 | 409 | ) |
402 | 410 |
|
| 411 | + def _codegen_epilogue_subtile_store( |
| 412 | + self, |
| 413 | + state: CodegenState, |
| 414 | + fake_tensor: torch.Tensor, |
| 415 | + indexing: BlockedSubscriptIndexing, |
| 416 | + store_value: ast.AST, |
| 417 | + ) -> ast.AST | None: |
| 418 | + # Currently support 2D tiles without permutations |
| 419 | + if len(indexing.block_shape) != 2 or len(indexing.offsets) != 2: |
| 420 | + return None |
| 421 | + |
| 422 | + env = CompileEnvironment.current() |
| 423 | + block_m, block_n = indexing.block_shape |
| 424 | + try: |
| 425 | + block_n_hint = env.size_hint(block_n) |
| 426 | + except Exception: |
| 427 | + return None |
| 428 | + |
| 429 | + if block_n_hint % 2 != 0: |
| 430 | + return None |
| 431 | + |
| 432 | + device_fn = state.device_function |
| 433 | + codegen = state.codegen |
| 434 | + |
| 435 | + block_m_str = device_fn.literal_expr(block_m) |
| 436 | + block_n_str = device_fn.literal_expr(block_n) |
| 437 | + indexing.block_shape[1] //= 2 |
| 438 | + desc_arg = indexing.tensor_descriptor_arg(state) |
| 439 | + |
| 440 | + if desc_arg.permutation is not None: |
| 441 | + return None |
| 442 | + |
| 443 | + |
| 444 | + block_n_half_str = f"({block_n_str} // 2)" |
| 445 | + |
| 446 | + # Lift the store value into a temporary variable for reuse |
| 447 | + acc_var = codegen.lift(store_value, prefix="acc") |
| 448 | + |
| 449 | + reshape_expr = expr_from_string( |
| 450 | + "tl.reshape({acc}, [{dim_m}, 2, {dim_half}])", |
| 451 | + acc=acc_var, |
| 452 | + dim_m=expr_from_string(block_m_str), |
| 453 | + dim_half=expr_from_string(block_n_half_str), |
| 454 | + ) |
| 455 | + reshape_var = codegen.lift(reshape_expr, prefix="acc") |
| 456 | + |
| 457 | + permute_expr = expr_from_string( |
| 458 | + "tl.permute({acc}, [0, 2, 1])", |
| 459 | + acc=reshape_var, |
| 460 | + ) |
| 461 | + permute_var = codegen.lift(permute_expr, prefix="acc") |
| 462 | + |
| 463 | + acc0_name = codegen.tmpvar(prefix="acc") |
| 464 | + acc1_name = codegen.tmpvar(prefix="acc") |
| 465 | + codegen.add_statement( |
| 466 | + statement_from_string( |
| 467 | + f"{acc0_name}, {acc1_name} = tl.split({{acc}})", |
| 468 | + acc=permute_var, |
| 469 | + ) |
| 470 | + ) |
| 471 | + acc0 = expr_from_string(acc0_name) |
| 472 | + acc1 = expr_from_string(acc1_name) |
| 473 | + |
| 474 | + desc_name = indexing.tensor_descriptor(state) |
| 475 | + offset0 = expr_from_string(indexing.offsets[0]) |
| 476 | + offset1 = expr_from_string(indexing.offsets[1]) |
| 477 | + |
| 478 | + # First subtile store |
| 479 | + codegen.add_statement( |
| 480 | + statement_from_string( |
| 481 | + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", |
| 482 | + off0=offset0, |
| 483 | + off1=offset1, |
| 484 | + value=acc0, |
| 485 | + ) |
| 486 | + ) |
| 487 | + |
| 488 | + offset1_shifted = expr_from_string( |
| 489 | + "({offset} + {half})", |
| 490 | + offset=expr_from_string(indexing.offsets[1]), |
| 491 | + half=expr_from_string(block_n_half_str), |
| 492 | + ) |
| 493 | + |
| 494 | + # Emit second subtile store as the expression returned to the caller |
| 495 | + return expr_from_string( |
| 496 | + f"{desc_name}.store([{{off0}}, {{off1}}], {{value}})", |
| 497 | + off0=offset0, |
| 498 | + off1=offset1_shifted, |
| 499 | + value=acc1, |
| 500 | + ) |
403 | 501 |
|
404 | 502 | class StackIndexingStrategy: |
405 | 503 | """ |
|
0 commit comments