Skip to content

Commit 16cea74

Browse files
Add cf dialect, region tests, and nvgpu tests (#213)
* Add test for misc region op and cf dialect * Add nvvm/nvgpu tests
1 parent 3457736 commit 16cea74

File tree

5 files changed

+1579
-1
lines changed

5 files changed

+1579
-1
lines changed
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from typing import List, Union
5+
6+
from ..util import Successor
7+
from ...dialects._cf_ops_gen import _Dialect
8+
from ...dialects._ods_common import (
9+
_cext,
10+
)
11+
from ...dialects.cf import *
12+
from ...ir import Block, InsertionPoint, Value
13+
14+
15+
@_cext.register_operation(_Dialect, replace=True)
class CondBranchOp(CondBranchOp):
    """Conditional-branch op of the cf dialect with ``Successor`` accessors.

    The class is registered against the cf ``_Dialect`` with ``replace=True``
    and adds two read-only properties, one per branch.
    """

    @property
    def true(self):
        """``Successor`` built from ``trueDestOperands`` and successor 0."""
        branch_operands = self.trueDestOperands
        branch_block = self.successors[0]
        return Successor(self, branch_operands, branch_block, 0)

    @property
    def false(self):
        """``Successor`` built from ``falseDestOperands`` and successor 1."""
        branch_operands = self.falseDestOperands
        branch_block = self.successors[1]
        return Successor(self, branch_operands, branch_block, 1)
24+
25+
26+
def br(
    dest: Union[Value, Block] = None, *dest_operands: List[Value], loc=None, ip=None
):
    """Create a ``BranchOp``.

    Args:
        dest: Target ``Block``. If a ``Value`` is passed instead, it is treated
            as the first branch operand and no explicit target is assumed.
            When no target is given, the block of the current
            ``InsertionPoint`` is used.
        *dest_operands: Remaining operands forwarded to the branch.
        loc: Forwarded to ``BranchOp``.
        ip: Forwarded to ``BranchOp``.

    Returns:
        The constructed ``BranchOp``.
    """
    if isinstance(dest, Value):
        # The leading positional argument is an operand, not a target block.
        dest_operands = [dest, *dest_operands]
        dest = None
    target = InsertionPoint.current.block if dest is None else dest
    return BranchOp(dest_operands, target, loc=loc, ip=ip)
35+
36+
37+
def cond_br(
    condition: Value,
    true_dest: Union[Value, Block] = None,
    false_dest: Union[Value, Block] = None,
    true_dest_operands: List[Value] = None,
    false_dest_operands: List[Value] = None,
    *,
    loc=None,
    ip=None,
):
    """Create a ``CondBranchOp``.

    Args:
        condition: Value selecting which branch is taken.
        true_dest: Target of the true branch; defaults to the block of the
            current ``InsertionPoint``.
        false_dest: Target of the false branch; same default as ``true_dest``.
        true_dest_operands: Operands for the true branch; defaults to ``[]``.
        false_dest_operands: Operands for the false branch; defaults to ``[]``.
        loc: Forwarded to ``CondBranchOp``.
        ip: Forwarded to ``CondBranchOp``.

    Returns:
        The constructed ``CondBranchOp``.
    """
    # The current insertion point is only consulted for a target left unset.
    true_target = InsertionPoint.current.block if true_dest is None else true_dest
    false_target = InsertionPoint.current.block if false_dest is None else false_dest
    true_args = [] if true_dest_operands is None else true_dest_operands
    false_args = [] if false_dest_operands is None else false_dest_operands
    return CondBranchOp(
        condition,
        true_args,
        false_args,
        true_target,
        false_target,
        loc=loc,
        ip=ip,
    )
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
from .gpu import smem_space
5+
from . import arith
6+
from ...dialects.nvgpu import *
7+
from ...ir import Attribute, Type
8+
from .. import types as T
9+
10+
11+
def nvgpu_type(mnemonic, attr_value):
    """Parse the textual type ``!nvgpu.<mnemonic><attr_value>`` via ``Type.parse``."""
    type_asm = f"!nvgpu.{mnemonic}<{attr_value}>"
    return Type.parse(type_asm)
13+
14+
15+
def barrier_group_t(num_barriers=1, address_space=None):
    """Return the ``!nvgpu.mbarrier.group`` type.

    Args:
        num_barriers: Value written into the ``num_barriers`` parameter.
        address_space: Value written into the ``memorySpace`` parameter;
            when ``None`` the result of ``smem_space()`` is used.
    """
    address_space = smem_space() if address_space is None else address_space
    group_params = f"memorySpace={address_space}, num_barriers = {num_barriers}"
    return nvgpu_type("mbarrier.group", group_params)
21+
22+
23+
def warpgroup_accumulator_t(M, N, dtype):
    """Return ``!nvgpu.warpgroup.accumulator`` fragmented as ``vector<MxNxdtype>``."""
    fragment_spec = f"fragmented=vector<{M}x{N}x{dtype}>"
    return nvgpu_type("warpgroup.accumulator", fragment_spec)
25+
26+
27+
def warpgroup_descriptor(M, N, dtype):
    """Return ``!nvgpu.warpgroup.descriptor`` over an ``MxNxdtype`` memref.

    The memref's memory space is whatever ``smem_space()`` returns.
    """
    tensor_spec = f"tensor=memref<{M}x{N}x{dtype}, {smem_space()}>"
    return nvgpu_type("warpgroup.descriptor", tensor_spec)
32+
33+
34+
# Keep handles to the star-imported builders before the wrappers defined in
# this module rebind the same names.
_mbarrier_init = mbarrier_init
_mbarrier_create = mbarrier_create


def mbarrier_create(num_barriers=1, address_space=None, *, loc=None, ip=None):
    """Call the underlying ``mbarrier_create`` with a computed result type.

    The ``barriers`` type is ``barrier_group_t(num_barriers, address_space)``;
    ``loc`` and ``ip`` are forwarded unchanged.
    """
    group_type = barrier_group_t(num_barriers, address_space)
    return _mbarrier_create(barriers=group_type, loc=loc, ip=ip)
44+
45+
46+
def mbarrier_init(barriers, count, mbar_id, *, predicate=None, loc=None, ip=None):
    """Call the underlying ``mbarrier_init``, accepting Python ints.

    ``count`` and ``mbar_id`` given as ``int`` are first turned into constants
    with ``arith.constant(..., index=True)``; any other value is passed through
    unchanged. ``predicate``, ``loc`` and ``ip`` are forwarded as-is.
    """
    # Constants are created in the same order as the parameters: count, mbar_id.
    count = arith.constant(count, index=True) if isinstance(count, int) else count
    mbar_id = (
        arith.constant(mbar_id, index=True) if isinstance(mbar_id, int) else mbar_id
    )
    forwarded = dict(
        barriers=barriers,
        count=count,
        mbar_id=mbar_id,
        predicate=predicate,
        loc=loc,
        ip=ip,
    )
    return _mbarrier_init(**forwarded)
59+
60+
61+
# Keep a handle to the star-imported builder before the wrapper below rebinds
# the name.
_mbarrier_arrive_expect_tx = mbarrier_arrive_expect_tx


def mbarrier_arrive_expect_tx(
    barriers, txcount, mbar_id, *, predicate=None, loc=None, ip=None
):
    """Call the underlying ``mbarrier_arrive_expect_tx``, accepting Python ints.

    ``txcount`` and ``mbar_id`` given as ``int`` are first turned into constants
    with ``arith.constant(..., index=True)``; other values pass through
    unchanged. ``predicate``, ``loc`` and ``ip`` are forwarded as-is.
    """
    # Constants are created in parameter order: txcount first, then mbar_id.
    txcount = (
        arith.constant(txcount, index=True) if isinstance(txcount, int) else txcount
    )
    mbar_id = (
        arith.constant(mbar_id, index=True) if isinstance(mbar_id, int) else mbar_id
    )
    forwarded = dict(
        barriers=barriers,
        txcount=txcount,
        mbar_id=mbar_id,
        predicate=predicate,
        loc=loc,
        ip=ip,
    )
    return _mbarrier_arrive_expect_tx(**forwarded)
79+
80+
81+
# Keep a handle to the star-imported builder before the wrapper below rebinds
# the name.
_tma_async_load = tma_async_load


def tma_async_load(
    dst,
    barriers,
    tensor_map_descriptor,
    coordinates,
    mbar_id,
    *,
    multicast_mask=None,
    predicate=None,
    loc=None,
    ip=None,
):
    """Call the underlying ``tma_async_load``, accepting Python ints.

    Args:
        dst: Forwarded unchanged.
        barriers: Forwarded unchanged.
        tensor_map_descriptor: Forwarded unchanged.
        coordinates: Iterable of coordinates. ``int`` entries are turned into
            constants with ``arith.constant(..., index=True)``; other entries
            are kept as-is. The caller's object is never modified, so tuples
            and other non-list iterables are accepted.
        mbar_id: Barrier id; an ``int`` is turned into an index constant.
        multicast_mask: Forwarded unchanged.
        predicate: Forwarded unchanged.
        loc: Forwarded unchanged.
        ip: Forwarded unchanged.

    Returns:
        Whatever the underlying ``tma_async_load`` builder returns.
    """
    # Build a fresh list instead of assigning into ``coordinates``: writing back
    # would mutate the caller's list and fail outright for an immutable
    # sequence such as a tuple. Constants are still created in coordinate order.
    coordinates = [
        arith.constant(c, index=True) if isinstance(c, int) else c
        for c in coordinates
    ]

    if isinstance(mbar_id, int):
        mbar_id = arith.constant(mbar_id, index=True)

    return _tma_async_load(
        dst=dst,
        barriers=barriers,
        tensor_map_descriptor=tensor_map_descriptor,
        coordinates=coordinates,
        mbar_id=mbar_id,
        multicast_mask=multicast_mask,
        predicate=predicate,
        loc=loc,
        ip=ip,
    )
114+
115+
116+
# Keep a handle to the star-imported builder before the wrapper below rebinds
# the name.
_mbarrier_try_wait_parity = mbarrier_try_wait_parity


def mbarrier_try_wait_parity(
    barriers, mbar_id, phase_parity=False, ticks=10000000, *, loc=None, ip=None
):
    """Call the underlying ``mbarrier_try_wait_parity``, accepting Python scalars.

    ``ticks`` and ``mbar_id`` given as ``int`` become constants via
    ``arith.constant(..., index=True)``; a ``bool`` ``phase_parity`` becomes a
    constant of type ``T.bool()``. Other values pass through unchanged, and
    ``loc`` / ``ip`` are forwarded as-is.
    """
    # Constants are created in this fixed order: ticks, mbar_id, phase_parity.
    ticks = arith.constant(ticks, index=True) if isinstance(ticks, int) else ticks
    mbar_id = (
        arith.constant(mbar_id, index=True) if isinstance(mbar_id, int) else mbar_id
    )
    phase_parity = (
        arith.constant(phase_parity, type=T.bool())
        if isinstance(phase_parity, bool)
        else phase_parity
    )
    forwarded = dict(
        barriers=barriers,
        phase_parity=phase_parity,
        ticks=ticks,
        mbar_id=mbar_id,
        loc=loc,
        ip=ip,
    )
    return _mbarrier_try_wait_parity(**forwarded)
136+
137+
138+
# Keep a handle to the star-imported builder before the wrapper below rebinds
# the name.
_warpgroup_mma = warpgroup_mma


def warpgroup_mma(
    matrix_c,
    descriptor_a,
    descriptor_b,
    *,
    wait_group=None,
    transpose_a=None,
    transpose_b=None,
    loc=None,
    ip=None,
):
    """Call the underlying ``warpgroup_mma`` with the result type taken from ``matrix_c``.

    ``matrix_d`` is set to ``matrix_c.type``; every other argument is forwarded
    unchanged.
    """
    return _warpgroup_mma(
        matrix_d=matrix_c.type,
        descriptor_a=descriptor_a,
        descriptor_b=descriptor_b,
        matrix_c=matrix_c,
        wait_group=wait_group,
        transpose_a=transpose_a,
        transpose_b=transpose_b,
        loc=loc,
        ip=ip,
    )

projects/eudsl-python-extras/mlir/extras/runtime/refbackend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,10 +223,11 @@ def wrapper(*args, **_kwargs):
223223
class LLVMJITBackend:
224224
def __init__(
225225
self,
226-
shared_lib_paths=None,
226+
shared_lib_paths: set[str | Path] | None = None,
227227
):
228228
if shared_lib_paths is None:
229229
shared_lib_paths = set()
230+
shared_lib_paths = set(shared_lib_paths)
230231
if platform.system() != "Windows":
231232
shared_lib_paths |= set(_exec_engine_shared_libs)
232233
self.shared_lib_paths = list(shared_lib_paths)

0 commit comments

Comments
 (0)