Skip to content

Commit 5f3ac07

Browse files
Add utility scripts (#214)
* Add utility scripts Adds the script to generate the pass pipeline class and to lift mlir to the Python bindings equivalent. * Fix imports
1 parent 16cea74 commit 5f3ac07

File tree

4 files changed

+664
-1
lines changed

4 files changed

+664
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*.lo
77
*.o
88
*.obj
9+
**/__pycache__
910

1011
# Precompiled Headers
1112
*.gch

projects/eudsl-python-extras/.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
15
# Byte-compiled / optimized / DLL files
2-
__pycache__/
36
*.py[cod]
47
*$py.class
58

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
import glob
5+
import json
6+
import keyword
7+
import platform
8+
import shutil
9+
import subprocess
10+
import sys
11+
from dataclasses import dataclass
12+
from pathlib import Path
13+
from subprocess import PIPE
14+
from textwrap import dedent, indent
15+
16+
from mlir._mlir_libs import include
17+
18+
include_path = Path(include.__path__[0])
19+
20+
21+
def dump_json(td_path: Path):
22+
llvm_tblgen_name = "llvm-tblgen"
23+
if platform.system() == "Windows":
24+
llvm_tblgen_name += ".exe"
25+
26+
# try from mlir-native-tools
27+
llvm_tblgen_path = Path(sys.prefix) / "bin" / llvm_tblgen_name
28+
# try to find using which
29+
if not llvm_tblgen_path.exists():
30+
llvm_tblgen_path = shutil.which(llvm_tblgen_name)
31+
assert Path(llvm_tblgen_path).exists() is not None, "couldn't find llvm-tblgen"
32+
33+
args = [f"-I{include_path}", f"-I{td_path.parent}", str(td_path), "-dump-json"]
34+
res = subprocess.run(
35+
[llvm_tblgen_path] + args,
36+
cwd=Path(".").cwd(),
37+
check=True,
38+
stdout=PIPE,
39+
stderr=subprocess.DEVNULL,
40+
)
41+
res_json = json.loads(res.stdout.decode("utf-8"))
42+
43+
return res_json
44+
45+
46+
@dataclass
47+
class Option:
48+
argument: str
49+
description: str
50+
type: str
51+
additional_opt_flags: str
52+
default_value: str
53+
list_option: bool = False
54+
55+
56+
@dataclass
57+
class Pass:
58+
name: str
59+
argument: str
60+
options: list[Option]
61+
description: str
62+
summary: str
63+
64+
65+
TYPE_MAP = {
66+
"::mlir::gpu::amd::Runtime": '"gpu::amd::Runtime"',
67+
"OpPassManager": '"OpPassManager"',
68+
"bool": "bool",
69+
"double": "float",
70+
"enum FusionMode": '"FusionMode"',
71+
"int": "int",
72+
"int32_t": "int",
73+
"int64_t": "int",
74+
"mlir::SparseParallelizationStrategy": '"SparseParallelizationStrategy"',
75+
"mlir::arm_sme::ArmStreaming": '"arm_sme::ArmStreaming"',
76+
"std::string": "str",
77+
"uint64_t": "int",
78+
"unsigned": "int",
79+
}
80+
81+
82+
def generate_pass_method(pass_: Pass):
83+
ident = 4
84+
py_args = []
85+
for o in pass_.options:
86+
argument = o.argument.replace("-", "_")
87+
if keyword.iskeyword(argument):
88+
argument += "_"
89+
type = TYPE_MAP.get(o.type, f"'{o.type}'")
90+
if o.list_option:
91+
type = f"List[{type}]"
92+
py_args.append((argument, type))
93+
94+
def print_options_doc_string(pass_):
95+
print(
96+
indent(
97+
f"'''{pass_.summary}",
98+
prefix=" " * ident * 2,
99+
)
100+
)
101+
if pass_.description:
102+
for l in pass_.description.split("\n"):
103+
print(
104+
indent(
105+
f"{l}",
106+
prefix=" " * ident,
107+
)
108+
)
109+
if pass_.options:
110+
print(
111+
indent(
112+
f"Args:",
113+
prefix=" " * ident * 2,
114+
)
115+
)
116+
for o in pass_.options:
117+
print(
118+
indent(
119+
f"{o.argument}: {o.description}",
120+
prefix=" " * ident * 3,
121+
)
122+
)
123+
print(
124+
indent(
125+
f"'''",
126+
prefix=" " * ident * 2,
127+
)
128+
)
129+
130+
pass_name = pass_.argument
131+
if py_args:
132+
py_args_str = ", ".join([f"{n}: {t} = None" for n, t in py_args])
133+
print(
134+
indent(
135+
f"def {pass_name.replace('-', '_')}(self, {py_args_str}):",
136+
prefix=" " * ident,
137+
)
138+
)
139+
print_options_doc_string(pass_)
140+
141+
mlir_args = []
142+
for n, t in py_args:
143+
if "list" in t:
144+
print(
145+
indent(
146+
f"if {n} is not None and isinstance({n}, (list, tuple)):",
147+
prefix=" " * ident * 2,
148+
)
149+
)
150+
print(indent(f"{n} = ','.join(map(str, {n}))", prefix=" " * ident * 3))
151+
mlir_args.append(f"{n}={n}")
152+
print(
153+
indent(
154+
dedent(
155+
f"""\
156+
self.add_pass("{pass_name}", {", ".join(mlir_args)})
157+
return self
158+
"""
159+
),
160+
prefix=" " * ident * 2,
161+
)
162+
)
163+
164+
else:
165+
print(
166+
indent(
167+
dedent(
168+
f"""\
169+
def {pass_name.replace("-", "_")}(self):"""
170+
),
171+
prefix=" " * ident,
172+
)
173+
)
174+
print_options_doc_string(pass_)
175+
print(
176+
indent(
177+
dedent(
178+
f"""\
179+
self.add_pass("{pass_name}")
180+
return self
181+
"""
182+
),
183+
prefix=" " * ident * 2,
184+
)
185+
)
186+
187+
188+
def gather_passes_from_td_json(j):
189+
passes = []
190+
for pass_ in j["!instanceof"]["Pass"] + j["!instanceof"]["InterfacePass"]:
191+
pass_ = j[pass_]
192+
options = []
193+
for o in pass_["options"]:
194+
option = j[o["def"]]
195+
option = Option(
196+
argument=option["argument"],
197+
description=option["description"],
198+
type=option["type"],
199+
additional_opt_flags=option["additionalOptFlags"],
200+
default_value=option["defaultValue"],
201+
list_option="ListOption" in option["!superclasses"],
202+
)
203+
options.append(option)
204+
pass_ = Pass(
205+
name=pass_["!name"],
206+
argument=pass_["argument"],
207+
options=options,
208+
description=pass_["description"],
209+
summary=pass_["summary"],
210+
)
211+
passes.append(pass_)
212+
213+
return passes
214+
215+
216+
if __name__ == "__main__":
217+
passes = []
218+
for td in glob.glob(str(include_path / "**" / "*.td"), recursive=True):
219+
try:
220+
j = dump_json(Path(td))
221+
if j["!instanceof"]["Pass"]:
222+
passes.extend(gather_passes_from_td_json(j))
223+
except Exception as e:
224+
print(f"Error parsing {td}: {e}")
225+
continue
226+
227+
for p in sorted(passes, key=lambda p: p.argument):
228+
generate_pass_method(p)

0 commit comments

Comments
 (0)