Skip to content

Commit 233b970

Browse files
Add llvm, transform, gpu dialects (#209)
Adds tests and dialects for rocdl, llvm, gpu and transform. Co-authored-by: Maksim Levental <[email protected]>
1 parent 77bd12b commit 233b970

File tree

10 files changed

+3363
-3
lines changed

10 files changed

+3363
-3
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/gpu.py

Lines changed: 602 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import Union, Optional
5+
6+
from ...util import infer_mlir_type
7+
8+
# noinspection PyUnresolvedReferences
9+
from ....dialects.llvm import *
10+
from ....ir import Type, Value, IntegerAttr, FloatAttr
11+
from ....dialects._ods_common import get_op_result_or_op_results
12+
13+
ValueRef = Value
14+
15+
16+
def llvm_ptr_t():
17+
return Type.parse("!llvm.ptr")
18+
19+
20+
try:
21+
from . import amdgcn
22+
except ImportError:
23+
pass
24+
25+
26+
def mlir_constant(
27+
value: Union[int, float, bool], type: Optional[Type] = None, *, loc=None, ip=None
28+
) -> Value:
29+
if type is None:
30+
type = infer_mlir_type(value, vector=False)
31+
32+
if isinstance(value, int):
33+
value = IntegerAttr.get(type, value)
34+
elif isinstance(value, float):
35+
value = FloatAttr.get(type, value)
36+
else:
37+
raise NotImplementedError(f"{value} is not a valid type")
38+
39+
return get_op_result_or_op_results(
40+
ConstantOp(res=value.type, value=value, loc=loc, ip=ip)
41+
)
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
import warnings
5+
6+
# noinspection PyUnresolvedReferences
7+
from ....dialects.llvm import *
8+
from ....ir import Type, F16Type, F32Type, F64Type, BF16Type, IntegerType
9+
10+
try:
11+
from llvm import intrinsic_is_overloaded, intrinsic_get_type, print_type_to_string
12+
from llvm import types_
13+
from llvm.context import context as llvm_context
14+
except ImportError:
15+
warnings.warn(
16+
"llvm bindings not installed; call_intrinsic won't work without supplying return type explicitly"
17+
)
18+
19+
20+
def mlir_type_to_llvm_type(mlir_type, llvm_ctx):
21+
if F16Type.isinstance(mlir_type):
22+
return types_.half_type_in_context(llvm_ctx)
23+
if F32Type.isinstance(mlir_type):
24+
return types_.float_type_in_context(llvm_ctx)
25+
if F64Type.isinstance(mlir_type):
26+
return types_.double_type_in_context(llvm_ctx)
27+
if BF16Type.isinstance(mlir_type):
28+
return types_.b_float_type_in_context(llvm_ctx)
29+
if IntegerType.isinstance(mlir_type):
30+
return types_.int_type_in_context(llvm_ctx, mlir_type.width)
31+
32+
raise NotImplementedError(f"{mlir_type} is not supported")
33+
34+
35+
def llvm_type_str_to_mlir_type(llvm_type: str):
36+
if llvm_type.startswith("<"):
37+
return Type.parse(f"vector{llvm_type}")
38+
if llvm_type == "float":
39+
return F32Type.get()
40+
raise NotImplementedError(f"{llvm_type} is not supported")
41+
42+
43+
_call_intrinsic = call_intrinsic
44+
45+
46+
def call_intrinsic(*args, **kwargs):
47+
intr_id = kwargs.pop("intr_id")
48+
intr_name = kwargs.pop("intr_name")
49+
mlir_ret_type = kwargs.pop("return_type", None)
50+
if mlir_ret_type:
51+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
52+
53+
is_overloaded = kwargs.pop("is_overloaded", None)
54+
if is_overloaded is None:
55+
is_overloaded = intrinsic_is_overloaded(intr_id)
56+
with llvm_context() as ctx:
57+
types = []
58+
if is_overloaded:
59+
types = [mlir_type_to_llvm_type(a.type, ctx.context) for a in args]
60+
intr_decl_fn_ty = intrinsic_get_type(ctx.context, intr_id, types)
61+
62+
ret_type_str = print_type_to_string(intr_decl_fn_ty).split(" (")[0].strip()
63+
mlir_ret_type = None
64+
if ret_type_str:
65+
mlir_ret_type = llvm_type_str_to_mlir_type(ret_type_str)
66+
67+
return _call_intrinsic(mlir_ret_type, intr_name, args, [], [])
68+
69+
70+
call_intrinsic_ = call_intrinsic
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from ..util import get_user_code_loc
5+
from ... import ir
6+
from ...dialects._ods_common import (
7+
_dispatch_mixed_values,
8+
_cext,
9+
get_op_results_or_values,
10+
get_default_loc_context,
11+
get_op_result_or_op_results,
12+
get_default_loc_context,
13+
segmented_accessor,
14+
)
15+
16+
# noinspection PyUnresolvedReferences
17+
from ...dialects.rocdl import *
18+
19+
_wmma_f16_16x16x16_f16 = wmma_f16_16x16x16_f16
20+
21+
22+
def wmma_f16_16x16x16_f16(A, B, C, *, opsel=False, loc=None, ip=None) -> ir.Value:
23+
v16 = ir.VectorType.get((16,), ir.F16Type.get())
24+
return _wmma_f16_16x16x16_f16(v16, A, B, C, opsel=opsel, loc=loc, ip=ip).result

projects/eudsl-python-extras/mlir/extras/dialects/transform.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,8 +375,6 @@ def _structured_bufferize_to_allocation(
375375
memory_space = StringAttr.get(memory_space)
376376

377377
return __structured_bufferize_to_allocation(
378-
allocated_buffer=transform_any_value_t(),
379-
new_ops=transform_any_op_t(),
380378
target=target,
381379
memory_space=memory_space,
382380
memcpy_op=memcpy_op,

projects/eudsl-python-extras/mlir/extras/testing/testing.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,13 @@ def filecheck(correct: str, module):
7272

7373
correct = "\n".join(filter(None, correct.splitlines()))
7474
correct = dedent(correct)
75-
correct_with_checks = main(correct).replace("CHECK:", "CHECK-NEXT:")
75+
correct_with_checks = main(correct).strip().splitlines()
76+
correct_with_checks = "\n".join(
77+
[
78+
(line.replace("CHECK:", "CHECK-NEXT:") if i > 0 else line)
79+
for i, line in enumerate(correct_with_checks)
80+
]
81+
)
7682

7783
filecheck_path = get_filecheck_path()
7884
with tempfile.NamedTemporaryFile() as tmp:

0 commit comments

Comments
 (0)