Skip to content

Commit f25e519

Browse files
committed
add fixes
1 parent 2f478bb commit f25e519

File tree

5 files changed

+17
-56
lines changed

5 files changed

+17
-56
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/llvm/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14
from typing import Union, Optional
25

36
from ...util import infer_mlir_type

projects/eudsl-python-extras/mlir/extras/dialects/llvm/util.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14
import warnings
25

36
# noinspection PyUnresolvedReferences
Lines changed: 6 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4-
from . import arith
54
from ..util import get_user_code_loc
6-
5+
from ... import ir
76
from ...dialects._ods_common import (
87
_dispatch_mixed_values,
98
_cext,
@@ -16,54 +15,13 @@
1615

1716
# noinspection PyUnresolvedReferences
1817
from ...dialects.rocdl import *
19-
from ...dialects._rocdl_ops_gen import _Dialect
20-
from ... import ir
21-
22-
23-
@_cext.register_operation(_Dialect, replace=True)
24-
class WMMA_F16_16X16X16_F16(ir.OpView):
25-
OPERATION_NAME = "rocdl.wmma.f16.16x16x16.f16"
26-
27-
_ODS_REGIONS = (0, True)
2818

29-
def __init__(self, res, args, *, loc=None, ip=None):
30-
operands = []
31-
results = []
32-
attributes = {}
33-
regions = None
34-
operands.extend(get_op_results_or_values(args))
35-
_ods_context = get_default_loc_context(loc)
36-
results.append(res)
37-
_ods_successors = None
38-
super().__init__(
39-
self.OPERATION_NAME,
40-
self._ODS_REGIONS,
41-
self._ODS_OPERAND_SEGMENTS,
42-
self._ODS_RESULT_SEGMENTS,
43-
attributes=attributes,
44-
results=results,
45-
operands=operands,
46-
successors=_ods_successors,
47-
regions=regions,
48-
loc=loc,
49-
ip=ip,
50-
)
51-
52-
@property
53-
def args(self):
54-
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
55-
return self.operation.operands[0 : 0 + _ods_variadic_group_length]
56-
57-
@property
58-
def res(self):
59-
return self.operation.results[0]
60-
61-
62-
wmma_f16_16x16x16_f16_ = wmma_f16_16x16x16_f16
19+
_wmma_f16_16x16x16_f16 = wmma_f16_16x16x16_f16
6320

6421

6522
def wmma_f16_16x16x16_f16(A, B, C, *, opsel=False, loc=None, ip=None) -> ir.Value:
23+
if loc is None:
24+
loc = get_user_code_loc()
25+
6626
v16 = ir.VectorType.get((16,), ir.F16Type.get())
67-
return wmma_f16_16x16x16_f16_(
68-
res=v16, A=A, B=B, C=C, opsel=opsel, loc=loc, ip=ip
69-
).result
27+
return _wmma_f16_16x16x16_f16(v16, A, B, C, opsel=opsel, loc=loc, ip=ip).result

projects/eudsl-python-extras/mlir/extras/dialects/transform.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,6 @@ def _structured_bufferize_to_allocation(
375375
memory_space = StringAttr.get(memory_space)
376376

377377
return __structured_bufferize_to_allocation(
378-
# allocated_buffer=transform_any_value_t(),
379-
# new_ops=transform_any_op_t(),
380378
target=target,
381379
memory_space=memory_space,
382380
memcpy_op=memcpy_op,

projects/eudsl-python-extras/tests/dialect/test_gpu.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -625,10 +625,9 @@ def fun2[foo, bar, A: foo + bar]():
625625
)
626626

627627

628-
@pytest.mark.xfail(reason="TODO: reification of ")
629628
@pytest.mark.skipif(
630-
sys.version_info < (3, 12) or platform.system() == "Windows",
631-
reason="requires python3.12 or higher (and windows can't find the source file)",
629+
sys.version_info < (3, 13) or platform.system() == "Windows",
630+
reason="requires python3.13 or higher (and windows can't find the source file)",
632631
)
633632
def test_generic_type_var_closure_patching_dependent_generics(ctx: MLIRContext):
634633
# dodge <3.12 parser that doesn't support square brackets generics
@@ -645,9 +644,9 @@ def test_plain[
645644
K,
646645
N,
647646
dtype,
648-
A_t: T.memref(M, K, dtype),
649-
B_t: T.memref(K, N, dtype),
650-
C_t: T.memref(M, N, dtype),
647+
A_t = T.memref(M, K, dtype),
648+
B_t = T.memref(K, N, dtype),
649+
C_t = T.memref(M, N, dtype),
651650
](A: A_t, B: B_t, C: C_t):
652651
one = arith.constant(1.0, type=dtype)
653652

0 commit comments

Comments
 (0)