|
| 1 | +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 2 | +# See https://llvm.org/LICENSE.txt for license information. |
| 3 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 4 | +import logging |
| 5 | +import os |
| 6 | +import sys |
| 7 | +import tempfile |
| 8 | +from contextlib import ExitStack |
| 9 | +from enum import StrEnum |
| 10 | +from io import StringIO |
| 11 | +from typing import List, Optional, Union |
| 12 | + |
| 13 | +from ..context import disable_multithreading |
| 14 | +from ...ir import Module, StringAttr |
| 15 | +from ...passmanager import PassManager |
| 16 | + |
| 17 | +logger = logging.getLogger(__name__) |
| 18 | + |
| 19 | + |
| 20 | +class MlirCompilerError(Exception): |
| 21 | + pass |
| 22 | + |
| 23 | + |
| 24 | +def get_module_name_for_debug_dump(module): |
| 25 | + if "debug_module_name" not in module.operation.attributes: |
| 26 | + return "UnnammedModule" |
| 27 | + return StringAttr(module.operation.attributes["debug_module_name"]).value |
| 28 | + |
| 29 | + |
| 30 | +def run_pipeline( |
| 31 | + module, |
| 32 | + pipeline: Union[str, "Pipeline"], |
| 33 | + description: Optional[str] = None, |
| 34 | + enable_ir_printing=False, |
| 35 | + print_pipeline=False, |
| 36 | + verify=True, |
| 37 | +): |
| 38 | + module = Module.parse(module.operation.get_asm(enable_debug_info=True)) |
| 39 | + |
| 40 | + if isinstance(pipeline, Pipeline): |
| 41 | + pipeline = str(pipeline) |
| 42 | + """Runs `pipeline` on `module`, with a nice repro report if it fails.""" |
| 43 | + module_name = get_module_name_for_debug_dump(module) |
| 44 | + try: |
| 45 | + original_stderr = sys.stderr |
| 46 | + sys.stderr = StringIO() |
| 47 | + # Lower module in place to make it ready for compiler backends. |
| 48 | + with ExitStack() as stack: |
| 49 | + stack.enter_context(module.context) |
| 50 | + asm_for_error_report = module.operation.get_asm( |
| 51 | + large_elements_limit=10, |
| 52 | + enable_debug_info=True, |
| 53 | + ) |
| 54 | + pm = PassManager.parse(pipeline) |
| 55 | + pm.enable_verifier(verify) |
| 56 | + if print_pipeline: |
| 57 | + print(pm) |
| 58 | + if enable_ir_printing: |
| 59 | + stack.enter_context(disable_multithreading()) |
| 60 | + pm.enable_ir_printing() |
| 61 | + |
| 62 | + pm.run(module.operation) |
| 63 | + except Exception as e: |
| 64 | + print(e, file=sys.stderr) |
| 65 | + filename = os.path.join(tempfile.gettempdir(), module_name + ".mlir") |
| 66 | + with open(filename, "w") as f: |
| 67 | + f.write(asm_for_error_report) |
| 68 | + debug_options = "-mlir-print-ir-after-all -mlir-disable-threading" |
| 69 | + description = description or f"{module_name} compile" |
| 70 | + |
| 71 | + message = f"""\ |
| 72 | + {description} failed with the following diagnostics: |
| 73 | +
|
| 74 | + {"*" * 80} |
| 75 | + {sys.stderr.getvalue().strip()} |
| 76 | + {"*" * 80} |
| 77 | +
|
| 78 | + For developers, the error can be reproduced with: |
| 79 | + $ mlir-opt {debug_options} -pass-pipeline='{pipeline}' {filename} |
| 80 | + """ |
| 81 | + trimmed_message = "\n".join([m.lstrip() for m in message.split("\n")]) |
| 82 | + raise MlirCompilerError(trimmed_message) |
| 83 | + finally: |
| 84 | + sys.stderr = original_stderr |
| 85 | + |
| 86 | + return module |
| 87 | + |
| 88 | + |
| 89 | +class Pipeline: |
| 90 | + _pipeline: List[str] = [] |
| 91 | + |
| 92 | + def __init__(self, pipeline=None, wrapper=None): |
| 93 | + if pipeline is None: |
| 94 | + pipeline = [] |
| 95 | + self._pipeline = pipeline |
| 96 | + |
| 97 | + def Nested(self, context, p: "Pipeline"): |
| 98 | + self._pipeline.append(f"{context}({p.materialize(module=False)})") |
| 99 | + return self |
| 100 | + |
| 101 | + def Func(self, p: "Pipeline"): |
| 102 | + return self.Nested("func.func", p) |
| 103 | + |
| 104 | + def Spirv(self, p: "Pipeline"): |
| 105 | + return self.Nested("spirv.module", p) |
| 106 | + |
| 107 | + def Gpu(self, p: "Pipeline"): |
| 108 | + assert isinstance(p, Pipeline) |
| 109 | + return self.Nested("gpu.module", p) |
| 110 | + |
| 111 | + def materialize(self, module=True): |
| 112 | + pipeline_str = ",".join(self._pipeline) |
| 113 | + if module: |
| 114 | + pipeline_str = f"builtin.module({pipeline_str})" |
| 115 | + logger.debug(f"{pipeline_str}") |
| 116 | + return pipeline_str |
| 117 | + |
| 118 | + def __str__(self): |
| 119 | + return self.materialize() |
| 120 | + |
| 121 | + def __iadd__(self, other: "Pipeline"): |
| 122 | + self._pipeline.extend(other._pipeline) |
| 123 | + return self |
| 124 | + |
| 125 | + def __add__(self, other: "Pipeline"): |
| 126 | + return Pipeline(self._pipeline + other._pipeline) |
| 127 | + |
| 128 | + def add_pass(self, pass_name, **kwargs): |
| 129 | + kwargs = { |
| 130 | + k: int(v) if isinstance(v, bool) else v |
| 131 | + for k, v in kwargs.items() |
| 132 | + if v is not None |
| 133 | + } |
| 134 | + if kwargs: |
| 135 | + args_str = " ".join(f"{k}={v}" for k, v in kwargs.items()) |
| 136 | + pass_str = f"{pass_name}{{ {args_str} }}" |
| 137 | + else: |
| 138 | + pass_str = f"{pass_name}" |
| 139 | + self._pipeline.append(pass_str) |
| 140 | + return self |
| 141 | + |
| 142 | + def lower_to_llvm(self, use_bare_ptr_memref_call_conv=False): |
| 143 | + # https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44 |
| 144 | + return ( |
| 145 | + self.Func( |
| 146 | + self.__class__() |
| 147 | + # Blanket-convert any remaining high-level vector ops to loops if any remain. |
| 148 | + .convert_vector_to_scf() |
| 149 | + # Blanket-convert any remaining linalg ops to loops if any remain. |
| 150 | + .convert_linalg_to_loops() |
| 151 | + ) |
| 152 | + # Blanket-convert any remaining affine ops if any remain. |
| 153 | + .lower_affine() |
| 154 | + # Convert SCF to CF (always needed). |
| 155 | + .convert_scf_to_cf() |
| 156 | + # Sprinkle some cleanups. |
| 157 | + .canonicalize() |
| 158 | + .cse() |
| 159 | + # Convert vector to LLVM (always needed). |
| 160 | + .convert_vector_to_llvm(force_32bit_vector_indices=True) |
| 161 | + # Convert Math to LLVM (always needed). |
| 162 | + .Func(self.__class__().convert_math_to_llvm()) |
| 163 | + # Expand complicated MemRef operations before lowering them. |
| 164 | + .expand_strided_metadata() |
| 165 | + # The expansion may create affine expressions. Get rid of them. |
| 166 | + .lower_affine() |
| 167 | + # Convert MemRef to LLVM (always needed). |
| 168 | + .finalize_memref_to_llvm() |
| 169 | + # Convert Func to LLVM (always needed). |
| 170 | + .convert_func_to_llvm( |
| 171 | + use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv |
| 172 | + ) |
| 173 | + .convert_arith_to_llvm() |
| 174 | + .convert_cf_to_llvm() |
| 175 | + # Convert Index to LLVM (always needed). |
| 176 | + .convert_index_to_llvm() |
| 177 | + # Convert UB to LLVM (always needed). |
| 178 | + .convert_ub_to_llvm() |
| 179 | + # Convert remaining unrealized_casts (always needed). |
| 180 | + .reconcile_unrealized_casts() |
| 181 | + ) |
| 182 | + |
| 183 | + def bufferize(self): |
| 184 | + return ( |
| 185 | + self.Func(self.__class__().empty_tensor_to_alloc_tensor()) |
| 186 | + .one_shot_bufferize() |
| 187 | + .Func(self.__class__().buffer_deallocation_simplification()) |
| 188 | + ) |
| 189 | + |
| 190 | + def lower_to_openmp(self): |
| 191 | + return self.convert_scf_to_openmp().Func(self.__class__().lower_affine()) |
| 192 | + |
| 193 | + def sparse_compiler( |
| 194 | + self, |
| 195 | + parallelization_strategy=None, |
| 196 | + enable_runtime_library=None, |
| 197 | + enable_buffer_initialization=None, |
| 198 | + vl=None, |
| 199 | + s2s_strategy=None, |
| 200 | + reassociate_fp_reductions=None, |
| 201 | + enable_index_optimizations=None, |
| 202 | + enable_amx=None, |
| 203 | + enable_arm_neon=None, |
| 204 | + enable_arm_sve=None, |
| 205 | + enable_x86vector=None, |
| 206 | + ): |
| 207 | + self.add_pass( |
| 208 | + "sparse-compiler", |
| 209 | + parallelization_strategy=parallelization_strategy, |
| 210 | + enable_runtime_library=enable_runtime_library, |
| 211 | + enable_buffer_initialization=enable_buffer_initialization, |
| 212 | + vl=vl, |
| 213 | + s2s_strategy=s2s_strategy, |
| 214 | + reassociate_fp_reductions=reassociate_fp_reductions, |
| 215 | + enable_index_optimizations=enable_index_optimizations, |
| 216 | + enable_amx=enable_amx, |
| 217 | + enable_arm_neon=enable_arm_neon, |
| 218 | + enable_arm_sve=enable_arm_sve, |
| 219 | + enable_x86vector=enable_x86vector, |
| 220 | + ) |
| 221 | + return self |
| 222 | + |
| 223 | + def lower_to_vulkan(self, index_bitwidth=None): |
| 224 | + return ( |
| 225 | + self.gpu_kernel_outlining() |
| 226 | + .fold_memref_alias_ops() |
| 227 | + .convert_gpu_to_spirv() |
| 228 | + .Spirv(self.__class__().spirv_lower_abi_attrs().spirv_update_vce()) |
| 229 | + .convert_gpu_launch_to_vulkan_launch() |
| 230 | + .finalize_memref_to_llvm() |
| 231 | + .Func(self.__class__().llvm_request_c_wrappers()) |
| 232 | + .convert_func_to_llvm(index_bitwidth=index_bitwidth) |
| 233 | + .reconcile_unrealized_casts() |
| 234 | + .launch_func_to_vulkan() |
| 235 | + ) |
| 236 | + |
| 237 | + |
| 238 | +class GreedySimplifyRegionLevel(StrEnum): |
| 239 | + DISABLED = "disabled" |
| 240 | + NORMAL = "normal" |
| 241 | + AGGRESSIVE = "aggressive" |
| 242 | + |
| 243 | + |
| 244 | +class SparseParallelizationStrategy(StrEnum): |
| 245 | + NONE = "none" |
| 246 | + DENSE_OUTER_LOOP = "dense-outer-loop" |
| 247 | + ANY_STORAGE_OUTER_LOOP = "any-storage-outer-loop" |
| 248 | + DENSE_ANY_LOOP = "dense-any-loop" |
| 249 | + ANY_STORAGE_ANY_LOOP = "any-storage-any-loop" |
0 commit comments