Skip to content

Commit 4f59995

Browse files
authored
[eudsl-python-extras] refactor pass Pipeline and generate_pass_pipeline script (#219)
This PR refactors `mlir/extras/runtime/passes.py` into two files - a base (`_passes_base.py`) which has handwritten code and a wholly generated `passes.py`. This enables us to quickly/easily generate the passes from TD without copy-pasting. The PR also regenerates the passes (so `Pipeline` is up to date) and also adds enums for pass options `SparseParallelizationStrategy` and `GreedySimplifyRegionLevel` (not for any special reason, just to demonstrate how to do this going forward). Also cleaned up how pass option strs are mangled (i.e., they're no longer mangled and unmangled in `add_pass`).
1 parent a6ed1ae commit 4f59995

File tree

3 files changed

+1822
-820
lines changed

3 files changed

+1822
-820
lines changed
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
import logging
5+
import os
6+
import sys
7+
import tempfile
8+
from contextlib import ExitStack
9+
from enum import StrEnum
10+
from io import StringIO
11+
from typing import List, Optional, Union
12+
13+
from ..context import disable_multithreading
14+
from ...ir import Module, StringAttr
15+
from ...passmanager import PassManager
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class MlirCompilerError(Exception):
21+
pass
22+
23+
24+
def get_module_name_for_debug_dump(module):
25+
if "debug_module_name" not in module.operation.attributes:
26+
return "UnnammedModule"
27+
return StringAttr(module.operation.attributes["debug_module_name"]).value
28+
29+
30+
def run_pipeline(
31+
module,
32+
pipeline: Union[str, "Pipeline"],
33+
description: Optional[str] = None,
34+
enable_ir_printing=False,
35+
print_pipeline=False,
36+
verify=True,
37+
):
38+
module = Module.parse(module.operation.get_asm(enable_debug_info=True))
39+
40+
if isinstance(pipeline, Pipeline):
41+
pipeline = str(pipeline)
42+
"""Runs `pipeline` on `module`, with a nice repro report if it fails."""
43+
module_name = get_module_name_for_debug_dump(module)
44+
try:
45+
original_stderr = sys.stderr
46+
sys.stderr = StringIO()
47+
# Lower module in place to make it ready for compiler backends.
48+
with ExitStack() as stack:
49+
stack.enter_context(module.context)
50+
asm_for_error_report = module.operation.get_asm(
51+
large_elements_limit=10,
52+
enable_debug_info=True,
53+
)
54+
pm = PassManager.parse(pipeline)
55+
pm.enable_verifier(verify)
56+
if print_pipeline:
57+
print(pm)
58+
if enable_ir_printing:
59+
stack.enter_context(disable_multithreading())
60+
pm.enable_ir_printing()
61+
62+
pm.run(module.operation)
63+
except Exception as e:
64+
print(e, file=sys.stderr)
65+
filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir")
66+
with open(filename, "w") as f:
67+
f.write(asm_for_error_report)
68+
debug_options = "-mlir-print-ir-after-all -mlir-disable-threading"
69+
description = description or f"{module_name} compile"
70+
71+
message = f"""\
72+
{description} failed with the following diagnostics:
73+
74+
{"*" * 80}
75+
{sys.stderr.getvalue().strip()}
76+
{"*" * 80}
77+
78+
For developers, the error can be reproduced with:
79+
$ mlir-opt {debug_options} -pass-pipeline='{pipeline}' {filename}
80+
"""
81+
trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")])
82+
raise MlirCompilerError(trimmed_message)
83+
finally:
84+
sys.stderr = original_stderr
85+
86+
return module
87+
88+
89+
class Pipeline:
90+
_pipeline: List[str] = []
91+
92+
def __init__(self, pipeline=None, wrapper=None):
93+
if pipeline is None:
94+
pipeline = []
95+
self._pipeline = pipeline
96+
97+
def Nested(self, context, p: "Pipeline"):
98+
self._pipeline.append(f"{context}({p.materialize(module=False)})")
99+
return self
100+
101+
def Func(self, p: "Pipeline"):
102+
return self.Nested("func.func", p)
103+
104+
def Spirv(self, p: "Pipeline"):
105+
return self.Nested("spirv.module", p)
106+
107+
def Gpu(self, p: "Pipeline"):
108+
assert isinstance(p, Pipeline)
109+
return self.Nested("gpu.module", p)
110+
111+
def materialize(self, module=True):
112+
pipeline_str = ",".join(self._pipeline)
113+
if module:
114+
pipeline_str = f"builtin.module({pipeline_str})"
115+
logger.debug(f"{pipeline_str}")
116+
return pipeline_str
117+
118+
def __str__(self):
119+
return self.materialize()
120+
121+
def __iadd__(self, other: "Pipeline"):
122+
self._pipeline.extend(other._pipeline)
123+
return self
124+
125+
def __add__(self, other: "Pipeline"):
126+
return Pipeline(self._pipeline + other._pipeline)
127+
128+
def add_pass(self, pass_name, **kwargs):
129+
kwargs = {
130+
k: int(v) if isinstance(v, bool) else v
131+
for k, v in kwargs.items()
132+
if v is not None
133+
}
134+
if kwargs:
135+
args_str = " ".join(f"{k}={v}" for k, v in kwargs.items())
136+
pass_str = f"{pass_name}{{ {args_str} }}"
137+
else:
138+
pass_str = f"{pass_name}"
139+
self._pipeline.append(pass_str)
140+
return self
141+
142+
def lower_to_llvm(self, use_bare_ptr_memref_call_conv=False):
143+
# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
144+
return (
145+
self.Func(
146+
self.__class__()
147+
# Blanket-convert any remaining high-level vector ops to loops if any remain.
148+
.convert_vector_to_scf()
149+
# Blanket-convert any remaining linalg ops to loops if any remain.
150+
.convert_linalg_to_loops()
151+
)
152+
# Blanket-convert any remaining affine ops if any remain.
153+
.lower_affine()
154+
# Convert SCF to CF (always needed).
155+
.convert_scf_to_cf()
156+
# Sprinkle some cleanups.
157+
.canonicalize()
158+
.cse()
159+
# Convert vector to LLVM (always needed).
160+
.convert_vector_to_llvm(force_32bit_vector_indices=True)
161+
# Convert Math to LLVM (always needed).
162+
.Func(self.__class__().convert_math_to_llvm())
163+
# Expand complicated MemRef operations before lowering them.
164+
.expand_strided_metadata()
165+
# The expansion may create affine expressions. Get rid of them.
166+
.lower_affine()
167+
# Convert MemRef to LLVM (always needed).
168+
.finalize_memref_to_llvm()
169+
# Convert Func to LLVM (always needed).
170+
.convert_func_to_llvm(
171+
use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv
172+
)
173+
.convert_arith_to_llvm()
174+
.convert_cf_to_llvm()
175+
# Convert Index to LLVM (always needed).
176+
.convert_index_to_llvm()
177+
# Convert UB to LLVM (always needed).
178+
.convert_ub_to_llvm()
179+
# Convert remaining unrealized_casts (always needed).
180+
.reconcile_unrealized_casts()
181+
)
182+
183+
def bufferize(self):
184+
return (
185+
self.Func(self.__class__().empty_tensor_to_alloc_tensor())
186+
.one_shot_bufferize()
187+
.Func(self.__class__().buffer_deallocation_simplification())
188+
)
189+
190+
def lower_to_openmp(self):
191+
return self.convert_scf_to_openmp().Func(self.__class__().lower_affine())
192+
193+
def sparse_compiler(
194+
self,
195+
parallelization_strategy=None,
196+
enable_runtime_library=None,
197+
enable_buffer_initialization=None,
198+
vl=None,
199+
s2s_strategy=None,
200+
reassociate_fp_reductions=None,
201+
enable_index_optimizations=None,
202+
enable_amx=None,
203+
enable_arm_neon=None,
204+
enable_arm_sve=None,
205+
enable_x86vector=None,
206+
):
207+
self.add_pass(
208+
"sparse-compiler",
209+
parallelization_strategy=parallelization_strategy,
210+
enable_runtime_library=enable_runtime_library,
211+
enable_buffer_initialization=enable_buffer_initialization,
212+
vl=vl,
213+
s2s_strategy=s2s_strategy,
214+
reassociate_fp_reductions=reassociate_fp_reductions,
215+
enable_index_optimizations=enable_index_optimizations,
216+
enable_amx=enable_amx,
217+
enable_arm_neon=enable_arm_neon,
218+
enable_arm_sve=enable_arm_sve,
219+
enable_x86vector=enable_x86vector,
220+
)
221+
return self
222+
223+
def lower_to_vulkan(self, index_bitwidth=None):
224+
return (
225+
self.gpu_kernel_outlining()
226+
.fold_memref_alias_ops()
227+
.convert_gpu_to_spirv()
228+
.Spirv(self.__class__().spirv_lower_abi_attrs().spirv_update_vce())
229+
.convert_gpu_launch_to_vulkan_launch()
230+
.finalize_memref_to_llvm()
231+
.Func(self.__class__().llvm_request_c_wrappers())
232+
.convert_func_to_llvm(index_bitwidth=index_bitwidth)
233+
.reconcile_unrealized_casts()
234+
.launch_func_to_vulkan()
235+
)
236+
237+
238+
class GreedySimplifyRegionLevel(StrEnum):
239+
DISABLED = "disabled"
240+
NORMAL = "normal"
241+
AGGRESSIVE = "aggressive"
242+
243+
244+
class SparseParallelizationStrategy(StrEnum):
245+
NONE = "none"
246+
DENSE_OUTER_LOOP = "dense-outer-loop"
247+
ANY_STORAGE_OUTER_LOOP = "any-storage-outer-loop"
248+
DENSE_ANY_LOOP = "dense-any-loop"
249+
ANY_STORAGE_ANY_LOOP = "any-storage-any-loop"

0 commit comments

Comments
 (0)