Skip to content

Commit f06ddf7

Browse files
Add llvm, xform, gpu dialects
Adds tests and dialects for rocdl, llvm, gpu and transform.
1 parent 352594f commit f06ddf7

File tree

9 files changed

+3372
-3
lines changed

9 files changed

+3372
-3
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/gpu.py

Lines changed: 602 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Union, Optional
2+
3+
from ...util import infer_mlir_type
4+
5+
# noinspection PyUnresolvedReferences
6+
from ....dialects.llvm import *
7+
from ....ir import Type, Value, IntegerAttr, FloatAttr
8+
from ....dialects._ods_common import get_op_result_or_op_results
9+
10+
ValueRef = Value
11+
12+
13+
def llvm_ptr_t():
14+
return Type.parse("!llvm.ptr")
15+
16+
17+
try:
18+
from . import amdgcn
19+
except ImportError:
20+
pass
21+
22+
23+
def mlir_constant(
24+
value: Union[int, float, bool], type: Optional[Type] = None, *, loc=None, ip=None
25+
) -> Value:
26+
if type is None:
27+
type = infer_mlir_type(value, vector=False)
28+
29+
if isinstance(value, int):
30+
value = IntegerAttr.get(type, value)
31+
elif isinstance(value, float):
32+
value = FloatAttr.get(type, value)
33+
else:
34+
raise NotImplementedError(f"{value} is not a valid type")
35+
36+
return get_op_result_or_op_results(
37+
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
38+
)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
import warnings
2+
3+
# noinspection PyUnresolvedReferences
4+
from ....dialects.llvm import *
5+
from ....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType
6+
7+
try:
8+
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
9+
from llvm import types_
10+
from llvm.context import context as llvm_context
11+
except ImportError:
12+
warnings.warn(
13+
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
14+
)
15+
16+
17+
def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
18+
if F16Type.isinstance(mlir_type):
19+
return types_.half_type_in_context(llvm_ctx)
20+
if F32Type.isinstance(mlir_type):
21+
return types_.float_type_in_context(llvm_ctx)
22+
if F64Type.isinstance(mlir_type):
23+
return types_.double_type_in_context(llvm_ctx)
24+
if BF16Type.isinstance(mlir_type):
25+
return types_.b_float_type_in_context(llvm_ctx)
26+
if IntegerType.isinstance(mlir_type):
27+
return types_.int_type_in_context(llvm_ctx, mlir_type.width)
28+
29+
raise NotImplementedError(f"{mlir_type} is not supported")
30+
31+
32+
def llvm_type_str_to_mlir_type(llvm_type: str):
33+
if llvm_type.startswith("<"):
34+
return Type.parse(f"vector{llvm_type}")
35+
if llvm_type == "float":
36+
return F32Type.get()
37+
raise NotImplementedError(f"{llvm_type} is not supported")
38+
39+
40+
_call_intrinsic = call_intrinsic
41+
42+
43+
def call_intrinsic(*args, **kwargs):
44+
intr_id = kwargs.pop("intr_id")
45+
intr_name = kwargs.pop("intr_name")
46+
mlir_ret_type = kwargs.pop("return_type", None)
47+
if mlir_ret_type:
48+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
49+
50+
is_overloaded = kwargs.pop("is_overloaded", None)
51+
if is_overloaded is None:
52+
is_overloaded = intrinsic_is_overloaded(intr_id)
53+
with llvm_context() as ctx:
54+
types = []
55+
if is_overloaded:
56+
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
57+
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)
58+
59+
ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
60+
mlir_ret_type = None
61+
if ret_type_str:
62+
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)
63+
64+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
65+
66+
67+
call_intrinsic_ = call_intrinsic
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from . import arith
5+
from ..util import get_user_code_loc
6+
7+
from ...dialects._ods_common import (
8+
_dispatch_mixed_values,
9+
_cext,
10+
get_op_results_or_values,
11+
get_default_loc_context,
12+
get_op_result_or_op_results,
13+
get_default_loc_context,
14+
segmented_accessor,
15+
)
16+
17+
# noinspection PyUnresolvedReferences
18+
from ...dialects.rocdl import *
19+
from ...dialects._rocdl_ops_gen import _Dialect
20+
from ... import ir
21+
22+
23+
@_cext.register_operation(_Dialect, replace=True)
24+
class WMMA_F16_16X16X16_F16(ir.OpView):
25+
OPERATION_NAME = "rocdl.wmma.f16.16x16x16.f16"
26+
27+
_ODS_REGIONS = (0, True)
28+
29+
def __init__(self, res, args, *, loc=None, ip=None):
30+
operands = []
31+
results = []
32+
attributes = {}
33+
regions = None
34+
operands.extend(get_op_results_or_values(args))
35+
_ods_context = get_default_loc_context(loc)
36+
results.append(res)
37+
_ods_successors = None
38+
super().__init__(
39+
self.OPERATION_NAME,
40+
self._ODS_REGIONS,
41+
self._ODS_OPERAND_SEGMENTS,
42+
self._ODS_RESULT_SEGMENTS,
43+
attributes=attributes,
44+
results=results,
45+
operands=operands,
46+
successors=_ods_successors,
47+
regions=regions,
48+
loc=loc,
49+
ip=ip,
50+
)
51+
52+
@property
53+
def args(self):
54+
_ods_variadic_group_length = len(self.operation.operands) - 1 + 1
55+
return self.operation.operands[0 : 0 + _ods_variadic_group_length]
56+
57+
@property
58+
def res(self):
59+
return self.operation.results[0]
60+
61+
62+
wmma_f16_16x16x16_f16_ = wmma_f16_16x16x16_f16
63+
64+
65+
def wmma_f16_16x16x16_f16(A, B, C, *, opsel=False, loc=None, ip=None) -> ir.Value:
66+
v16 = ir.VectorType.get((16,), ir.F16Type.get())
67+
return wmma_f16_16x16x16_f16_(
68+
res=v16, A=A, B=B, C=C, opsel=opsel, loc=loc, ip=ip
69+
).result

projects/eudsl-python-extras/mlir/extras/dialects/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,8 @@ def _structured_bufferize_to_allocation(
375375
memory_space = StringAttr.get(memory_space)
376376

377377
return __structured_bufferize_to_allocation(
378-
allocated_buffer=transform_any_value_t(),
379-
new_ops=transform_any_op_t(),
378+
# allocated_buffer=transform_any_value_t(),
379+
# new_ops=transform_any_op_t(),
380380
target=target,
381381
memory_space=memory_space,
382382
memcpy_op=memcpy_op,

projects/eudsl-python-extras/mlir/extras/testing/testing.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ def filecheck(correct: str, module):
7272

7373
correct = "\n".join(filter(None, correct.splitlines()))
7474
correct = dedent(correct)
75-
correct_with_checks = main(correct).replace("CHECK:", "CHECK-NEXT:")
75+
correct_with_checks = main(correct).strip().splitlines()
76+
correct_with_checks = "\n".join(
77+
[
78+
(line.replace("CHECK:", "CHECK-NEXT:") if i > 0 else line)
79+
for i, line in enumerate(correct_with_checks)
80+
]
81+
)
7682

7783
filecheck_path = get_filecheck_path()
7884
with tempfile.NamedTemporaryFile() as tmp:
@@ -81,13 +87,17 @@ def filecheck(correct: str, module):
8187
p = Popen([filecheck_path, tmp.name], stdout=PIPE, stdin=PIPE, stderr=PIPE)
8288
out, err = map(lambda o: o.decode(), p.communicate(input=op.encode()))
8389
if p.returncode:
90+
breakpoint()
91+
if "error: " in err:
92+
raise RuntimeError(err)
8493
diff = list(
8594
difflib.unified_diff(
8695
op.splitlines(), # to this
8796
correct.splitlines(), # delta from this
8897
lineterm="",
8998
)
9099
)
100+
breakpoint()
91101
diff.insert(1, "delta from module to correct")
92102
print("lit report:", err, file=sys.stderr)
93103
raise ValueError("\n" + "\n".join(diff))

0 commit comments

Comments
 (0)