diff --git a/.github/workflows/build_mlir_python_bindings_wheel.yml b/.github/workflows/build_mlir_python_bindings_wheel.yml
index 16fe30a6..0e42ed48 100644
--- a/.github/workflows/build_mlir_python_bindings_wheel.yml
+++ b/.github/workflows/build_mlir_python_bindings_wheel.yml
@@ -206,7 +206,7 @@ jobs:
           "windows-2022"
         ]
         python-version: [
-          # "3.10", "3.11", "3.12",
+          "3.10", "3.11", "3.12",
           "3.13", "3.14", "3.14t"
         ]
         include: [
@@ -230,6 +230,21 @@ jobs:
           - runs-on: macos-13
             python-version: "3.14t"

+          - runs-on: macos-14
+            python-version: "3.10"
+
+          - runs-on: macos-14
+            python-version: "3.11"
+
+          - runs-on: macos-14
+            python-version: "3.12"
+
+          - runs-on: macos-14
+            python-version: "3.13"
+
+          - runs-on: macos-14
+            python-version: "3.14"
+
     runs-on: ${{ matrix.runs-on }}

     name: "Test mlir-python-bindings ${{ matrix.name }} ${{ matrix.python-version }}"
diff --git a/.github/workflows/build_test_release_eudsl.yml b/.github/workflows/build_test_release_eudsl.yml
index ff26c344..bb8bb808 100644
--- a/.github/workflows/build_test_release_eudsl.yml
+++ b/.github/workflows/build_test_release_eudsl.yml
@@ -277,8 +277,8 @@ jobs:
           # "macos-13", "macos-14", "windows-2022"]
         python-version: [
-          # "3.9", "3.10", "3.11",
-          "3.12", "3.13"
+          "3.10", "3.11", "3.12",
+          "3.13", "3.14", "3.14t"
         ]
         include: [
           {runs-on: "ubuntu-22.04", name: "ubuntu_x86_64", os: "ubuntu"},
@@ -291,6 +291,21 @@ jobs:
           - runs-on: macos-13
             python-version: "3.13"

+          - runs-on: macos-14
+            python-version: "3.10"
+
+          - runs-on: macos-14
+            python-version: "3.11"
+
+          - runs-on: macos-14
+            python-version: "3.12"
+
+          - runs-on: macos-14
+            python-version: "3.13"
+
+          - runs-on: macos-14
+            python-version: "3.14"
+
     runs-on: ${{ matrix.runs-on }}

     name: "Test tblgen ${{ matrix.name }} ${{ matrix.python-version }}"
@@ -306,7 +321,7 @@ jobs:
           submodules: false

       - name: "Install Python"
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6.0.0
         with:
           python-version: "${{ matrix.python-version }}"
@@ -348,8 +363,11 @@ jobs:
           "windows-2022"
         ]
         python-version: [
+          # we only build 3.12 and up
           # "3.10", "3.11",
-          "3.12", "3.13"
+          "3.12", "3.13", "3.14",
+          # stable abi doesn't support 3.14t?
+          # "3.14t"
         ]
         include: [
           {runs-on: "ubuntu-22.04", name: "ubuntu_x86_64", os: "ubuntu"},
@@ -362,6 +380,18 @@ jobs:
           - runs-on: macos-13
             python-version: "3.13"

+          - runs-on: macos-14
+            python-version: "3.10"
+
+          - runs-on: macos-14
+            python-version: "3.11"
+
+          - runs-on: macos-14
+            python-version: "3.12"
+
+          - runs-on: macos-14
+            python-version: "3.13"
+
     runs-on: ${{ matrix.runs-on }}

     name: "Test llvmpy ${{ matrix.name }} ${{ matrix.python-version }}"
@@ -377,7 +407,7 @@ jobs:
           submodules: false

       - name: "Install Python"
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6.0.0
         with:
           python-version: "${{ matrix.python-version }}"
@@ -445,7 +475,7 @@ jobs:
     if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch'
-    needs: [test-eudsl-llvmpy, test-eudsl-tblgen]
+    needs: [release-eudsl]

     permissions:
       contents: read
diff --git a/.github/workflows/build_test_release_eudsl_python_extras.yml b/.github/workflows/build_test_release_eudsl_python_extras.yml
index facaddb2..b989b8d3 100644
--- a/.github/workflows/build_test_release_eudsl_python_extras.yml
+++ b/.github/workflows/build_test_release_eudsl_python_extras.yml
@@ -94,7 +94,7 @@ jobs:
           "windows-2022"
         ]
         python-version: [
-          # "3.10", "3.11", "3.12",
+          "3.10", "3.11", "3.12",
           "3.13", "3.14", "3.14t"
         ]
         include: [
@@ -118,6 +118,21 @@ jobs:
           - runs-on: macos-13
             python-version: "3.14t"

+          - runs-on: macos-14
+            python-version: "3.10"
+
+          - runs-on: macos-14
+            python-version: "3.11"
+
+          - runs-on: macos-14
+            python-version: "3.12"
+
+          - runs-on: macos-14
+            python-version: "3.13"
+
+          - runs-on: macos-14
+            python-version: "3.14"
+
     runs-on: ${{ matrix.runs-on }}

     name: "Test eudsl-python-extras ${{ matrix.name }} ${{ matrix.python-version }}"
diff --git a/.github/workflows/clean_releases.yml b/.github/workflows/clean_releases.yml
index e9312db5..ebc94b56 100644
--- a/.github/workflows/clean_releases.yml
+++ b/.github/workflows/clean_releases.yml
@@ -21,7 +21,7 @@ jobs:
         uses: actions/checkout@v2

      - name: Setup Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v6.0.0
        with:
          python-version: "3.10"
diff --git a/projects/eudsl-llvmpy/src/llvm/function.py b/projects/eudsl-llvmpy/src/llvm/function.py
index 85c12061..a17f94cd 100644
--- a/projects/eudsl-llvmpy/src/llvm/function.py
+++ b/projects/eudsl-llvmpy/src/llvm/function.py
@@ -65,16 +65,17 @@ def __init__(self, body_builder, *, return_type=None, entry_bb_name="entry"):

     def _is_decl(self):
         # magic constant found from looking at the code for an empty fn
+        if sys.version_info.minor == 14:
+            return self.body_builder.__code__.co_code == b"\x80\x00R\x00#\x00"
         if sys.version_info.minor == 13:
             return self.body_builder.__code__.co_code == b"\x95\x00g\x00"
-        elif sys.version_info.minor == 12:
+        if sys.version_info.minor == 12:
             return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
-        elif sys.version_info.minor == 11:
+        if sys.version_info.minor == 11:
             return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
-        elif sys.version_info.minor in {8, 9, 10}:
+        if sys.version_info.minor in {8, 9, 10}:
             return self.body_builder.__code__.co_code == b"d\x00S\x00"
-        else:
-            raise NotImplementedError(f"{sys.version_info.minor} not supported.")
+        raise NotImplementedError(f"{sys.version_info.minor} not supported.")

     def __str__(self):
         return str(f"{self.__class__} {self.__dict__}")
diff --git a/projects/eudsl-python-extras/mlir/extras/ast/util.py b/projects/eudsl-python-extras/mlir/extras/ast/util.py
index b176eecf..32f5989c 100644
--- a/projects/eudsl-python-extras/mlir/extras/ast/util.py
+++ b/projects/eudsl-python-extras/mlir/extras/ast/util.py
@@ -92,7 +92,7 @@ def replace_closure(code, new_closure: Dict):
     LOAD_DEREF = opmap["LOAD_DEREF"]

     # get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i
-    localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code)
+    localsplus, _localsplus_name_to_idx = get_localsplus_name_to_idx(code)
     # closure vars go into co_freevars
     new_code = code.replace(co_freevars=tuple(new_closure.keys()))
@@ -130,18 +130,23 @@ def reducer_override(self, obj):
         return super().reducer_override(obj)


+def copy_object(obj):
+    # see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
+    # for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
+    with io.BytesIO() as file:
+        cp = MLIRTypePickler(file)
+        cp.dump(obj)
+        obj = cloudpickle.loads(file.getvalue())
+    return obj
+
+
 # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard);
 # potentially more complete approach https://stackoverflow.com/a/56901529/9045206
 def copy_func(f, new_closure: Dict = None):
     if new_closure is not None:
         code, closure = replace_closure(f.__code__, new_closure)
     else:
-        # see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
-        # for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
-        with io.BytesIO() as file:
-            cp = MLIRTypePickler(file)
-            cp.dump(f.__closure__)
-            closure = cloudpickle.loads(file.getvalue())
+        closure = copy_object(f.__closure__)
         code = f.__code__

     g = types.FunctionType(
diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py
index 02fbabd7..774b724b 100644
--- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py
+++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py
@@ -8,7 +8,7 @@
 from functools import update_wrapper
 from typing import Optional, List, Union, TypeVar

-from ..ast.util import copy_func
+from ..ast.util import copy_func, copy_object
 from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
 from ..meta import op_region_builder
 from .. import types as T
@@ -114,9 +114,7 @@ def prep_func_types(sig, return_types):
     return_types = list(return_types)
     assert all(
         isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types
-    ), (
-        f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}"
-    )
+    ), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}"

     input_types = [
         p.annotation
@@ -125,9 +123,7 @@ def prep_func_types(sig, return_types):
     ]
     assert all(
         isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types
-    ), (
-        f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}"
-    )
+    ), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}"
     user_loc = get_user_code_loc()
     # If ir.Context is none (like for deferred func emit)
     if user_loc is None:
@@ -179,7 +175,7 @@ def __init__(
         self.call_op_ctor = call_op_ctor
         self.arg_attrs = arg_attrs
         self.res_attrs = res_attrs
-        self.generics = generics
+        self.generics = copy_object(generics)
         self.loc = loc
         self.ip = ip
         self._func_op = None
@@ -200,9 +196,9 @@ def __init__(
         )

         if self._is_decl():
-            assert len(self.input_types) == len(sig.parameters), (
-                f"func decl needs all input types annotated"
-            )
+            assert len(self.input_types) == len(
+                sig.parameters
+            ), f"func decl needs all input types annotated"
             self.sym_visibility = "private"

         self.emit()
@@ -210,16 +206,15 @@ def _is_decl(self):
         # magic constant found from looking at the code for an empty fn
         if sys.version_info.minor == 14:
             return self.body_builder.__code__.co_code == b"\x80\x00R\x00#\x00"
-        elif sys.version_info.minor == 13:
+        if sys.version_info.minor == 13:
             return self.body_builder.__code__.co_code == b"\x95\x00g\x00"
-        elif sys.version_info.minor == 12:
+        if sys.version_info.minor == 12:
             return self.body_builder.__code__.co_code == b"\x97\x00y\x00"
-        elif sys.version_info.minor == 11:
+        if sys.version_info.minor == 11:
             return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00"
-        elif sys.version_info.minor in {8, 9, 10}:
+        if sys.version_info.minor in {8, 9, 10}:
             return self.body_builder.__code__.co_code == b"d\x00S\x00"
-        else:
-            raise NotImplementedError(f"{sys.version_info.minor} not supported.")
+        raise NotImplementedError(f"{sys.version_info.minor} not supported.")

     def __str__(self):
         return str(f"{self.__class__} {self.__dict__}")
@@ -302,6 +297,8 @@ def __call__(self, *call_args):
         return call(self.emit(*call_args), call_args)

     def __getitem__(self, item):
+        if not isinstance(item, tuple):
+            item = (item,)
         if self.generics is None:
             raise RuntimeError(
                 "using a generic call requires the func be generic (i.e., have type_params)"
@@ -326,22 +323,27 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
                     continue
                 if k not in already_reified_type_params:
                     raise RuntimeError(
-                        f"typevar {k} not reified prior to evaluating dependent typevar {t}"
+                        f"typevar {k} not reified prior to evaluating dependent typevar {tvar}"
                     )
                 cvrs[k] = already_reified_type_params[k]
             unevaled_type_data = copy_func(unevaled_type_data, cvrs)
             return unevaled_type_data()

-        generics = copy.deepcopy(self.generics)
-        for i, t in enumerate(generics):
+        generics = copy_object(self.generics)
+        for i, tvar in enumerate(generics):
+            if not isinstance(tvar, TypeVar):
+                raise RuntimeError(
+                    f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, "
+                    f"you should use a unique one for each generic function."
+                )
             type_var_default = None
             if sys.version_info >= (3, 12):
-                type_var = PyTypeVarObject.from_object(t)
+                type_var = PyTypeVarObject.from_object(tvar)
                 type_var_bound = type_var.bound
-                if sys.version_info >= (3, 13) and t.has_default():
+                if sys.version_info >= (3, 13) and tvar.has_default():
                     type_var_default = type_var.default_value
             else:
-                type_var_bound = t.__bound__
+                type_var_bound = tvar.__bound__

             if bool(type_var_bound):
                 # before 3.12 typevar was just a python class
@@ -350,7 +352,7 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
                 type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
             elif not bool(type_var_default):
                 if i >= len(item):
-                    raise RuntimeError(f"generic {t} must have concrete val")
+                    raise RuntimeError(f"generic {tvar=} must have concrete val")
                 if isinstance(item[i], Type):
                     type_var_bound = "type"
                 else:
@@ -362,23 +364,32 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
                     val = type_var_default
                 else:
                     if i >= len(item):
-                        raise RuntimeError(f"generic {t} must have concrete val")
+                        raise RuntimeError(f"generic {tvar=} must have concrete val")
                     val = item[i]

-            r = ReifiedTypeParams(t.__name__, val, type_var_bound)
+            r = ReifiedTypeParams(tvar.__name__, val, type_var_bound)
             reified_type_params.append(r)
             already_reified_type_params[r.name] = r.val

-            if t.__name__ in body_builder.__globals__:
-                body_builder.__globals__[t.__name__] = r.val
+            # replace the tvar in body_builder's global context with the reified val
+            if tvar.__name__ in body_builder.__globals__:
+                body_builder.__globals__[tvar.__name__] = r.val
+            # replace the tvar in body_builder's closure with the reified val
             if r.name in body_builder.__code__.co_freevars:
                 free_i = body_builder.__code__.co_freevars.index(r.name)
-                assert body_builder.__closure__[free_i].cell_contents == t, (
-                    "typevars don't match"
-                )
+                assert (
+                    body_builder.__closure__[free_i].cell_contents == tvar
+                ), "typevars don't match"
                 body_builder.__closure__[free_i].cell_contents = r.val

+        name_mangled_generics = []
+        for r in reified_type_params:
+            tvar, v = r.type, r.val
+            if callable(v):
+                v = v.__name__
+            name_mangled_generics.append(f"{tvar}_{v}")
+
         return FuncBase(
             body_builder,
             self.func_op_ctor,
@@ -386,11 +397,7 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
             self.call_op_ctor,
             return_types=self.return_types,
             sym_visibility=self.sym_visibility,
-            sym_name=(
-                self.func_name
-                + "_"
-                + "_".join([f"{r.type}_{r.val}" for r in reified_type_params])
-            ),
+            sym_name=(self.func_name + "_" + "_".join(name_mangled_generics)),
             arg_attrs=self.arg_attrs,
             res_attrs=self.res_attrs,
             func_attrs=self.func_attrs,
diff --git a/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py b/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py
index 7c576ef4..58bec28f 100644
--- a/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py
+++ b/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py
@@ -6,7 +6,6 @@
 import sys
 import tempfile
 from contextlib import ExitStack
-from enum import StrEnum
 from io import StringIO
 from typing import List, Optional, Union
@@ -14,6 +13,16 @@
 from ...ir import Module, StringAttr
 from ...passmanager import PassManager

+
+try:
+    from enum import StrEnum
+except ImportError:
+    from enum import Enum
+
+    class StrEnum(str, Enum):
+        pass
+
+
 logger = logging.getLogger(__name__)
diff --git a/projects/eudsl-python-extras/requirements.txt b/projects/eudsl-python-extras/requirements.txt
index a1bbc682..6e31ad2c 100644
--- a/projects/eudsl-python-extras/requirements.txt
+++ b/projects/eudsl-python-extras/requirements.txt
@@ -1,4 +1,4 @@
 PyYAML>=5.4.0
-bytecode @ git+https://github.com/MatthieuDartiailh/bytecode
+bytecode>=0.17.0
 cloudpickle>=3.0.0
 numpy>=1.19.5, <=2.1.2
diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py
index e725db50..261c18e3 100644
--- a/projects/eudsl-python-extras/tests/dialect/test_func.py
+++ b/projects/eudsl-python-extras/tests/dialect/test_func.py
@@ -205,15 +205,94 @@ def mat_product_kernel(
         one = arith.constant(1, dtype)

     mat_product_kernel[32, 32, 32, T.i32()].emit()
+    mat_product_kernel[32, 32, 32, T.f32()].emit()

     # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_i32(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) {
     # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
     # CHECK: return
     # CHECK: }
+    # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) {
+    # CHECK: %cst = arith.constant 1.000000e+00 : f32
+    # CHECK: return
+    # CHECK: }

     filecheck_with_comments(ctx.module)


+def test_generics_callable(ctx: MLIRContext):
+    _op = TypeVar("_op")
+
+    @func(generics=[_op])
+    def mat_product_kernel1():
+        one = arith.constant(1, T.f32())
+        two = _op(one, one)
+
+    @func(generics=[_op])
+    def mat_product_kernel2():
+        one = arith.constant(1, T.f32())
+        two = _op(one, one)
+
+    mat_product_kernel1[arith.maximumf].emit()
+    mat_product_kernel2[arith.minimumf].emit()
+    mat_product_kernel2[arith.maximumf].emit()
+
+    # CHECK: func.func @mat_product_kernel1_function_maximumf() {
+    # CHECK: %cst = arith.constant 1.000000e+00 : f32
+    # CHECK: %0 = arith.maximumf %cst, %cst : f32
+    # CHECK: return
+    # CHECK: }
+    # CHECK: func.func @mat_product_kernel2_function_minimumf() {
+    # CHECK: %cst = arith.constant 1.000000e+00 : f32
+    # CHECK: %0 = arith.minimumf %cst, %cst : f32
+    # CHECK: return
+    # CHECK: }
+    # CHECK: func.func @mat_product_kernel2_function_maximumf() {
+    # CHECK: %cst = arith.constant 1.000000e+00 : f32
+    # CHECK: %0 = arith.maximumf %cst, %cst : f32
+    # CHECK: return
+    # CHECK: }
+
+    filecheck_with_comments(ctx.module)
+
+
+_op = TypeVar("_op")
+
+
+def test_global_closures(ctx: MLIRContext):
+    globals()["_op"] = TypeVar("_op")
+
+    @func(generics=[_op])
+    def _generic_pool2d_scf(a: T.f32(), b: T.f32()):
+        _op(a, b)
+
+    _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf]
+
+    @func(generics=[_op])
+    def _generic_pool3d_scf(
+        a: T.f32(),
+        b: T.f32(),
+    ):
+        _op(a, b)
+
+    with pytest.raises(
+        RuntimeError,
+        match="0th generic has probably already been reified as ; if you're using a global tvar for the generic,"
+        " you should use a unique one for each generic function.",
+    ):
+        _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf]
+
+    _op1 = TypeVar("_op1")
+
+    @func(generics=[_op1])
+    def _generic_pool3d_scf(
+        a: T.f32(),
+        b: T.f32(),
+    ):
+        _op(a, b)
+
+    _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf]
+
+
 def test_generics_with_canonicalizations(ctx: MLIRContext):
     generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"]))
diff --git a/projects/eudsl-python-extras/tests/dialect/test_linalg.py b/projects/eudsl-python-extras/tests/dialect/test_linalg.py
index 0fd5c32c..d76fa93e 100644
--- a/projects/eudsl-python-extras/tests/dialect/test_linalg.py
+++ b/projects/eudsl-python-extras/tests/dialect/test_linalg.py
@@ -18,6 +18,7 @@
     filecheck_with_comments,
     mlir_ctx as ctx,
 )
+from mlir.extras.runtime.passes import Pipeline, run_pipeline

 # needed since the fix isn't defined here nor conftest.py
 pytest.mark.usefixtures("ctx")
@@ -134,3 +135,118 @@ def maxpool3d(
     # CHECK: return
     # CHECK: }
     filecheck_with_comments(maxpool3d_k)
+
+
+def test_pooling_ncdhw_max_parallel(ctx: MLIRContext):
+    S = ShapedType.get_dynamic_size()
+
+    generics = (
+        kernel_size_0,
+        kernel_size_1,
+        kernel_size_2,
+        stride_0,
+        stride_1,
+        stride_2,
+        dilation_0,
+        dilation_1,
+        dilation_2,
+    ) = list(
+        map(
+            TypeVar,
+            [
+                "kernel_size_0",
+                "kernel_size_1",
+                "kernel_size_2",
+                "stride_0",
+                "stride_1",
+                "stride_2",
+                "dilation_0",
+                "dilation_1",
+                "dilation_2",
+            ],
+        )
+    )
+
+    @func(
+        generics=(
+            kernel_size_0,
+            kernel_size_1,
+            kernel_size_2,
+            stride_0,
+            stride_1,
+            stride_2,
+            dilation_0,
+            dilation_1,
+            dilation_2,
+        )
+    )
+    def maxpool3d(
+        input: T.memref(S, S, S, S, S, T.f32()),
+        output: T.memref(S, S, S, S, S, T.f32()),
+    ):
+        kernel_shape_surrogate = memref.alloca(
+            (kernel_size_0, kernel_size_1, kernel_size_2),
+            T.f32(),
+        )
+
+        linalg.pooling_ncdhw_max(
+            input,
+            kernel_shape_surrogate,
+            output,
+            strides=[stride_0, stride_1, stride_2],
+            dilations=[dilation_0, dilation_1, dilation_2],
+        )
+
+    kernel_sizes = [1, 2, 3]
+    strides = [4, 5, 6]
+    dilations = [7, 8, 9]
+    maxpool3d_k = maxpool3d[
+        kernel_sizes[0],
+        kernel_sizes[1],
+        kernel_sizes[2],
+        strides[0],
+        strides[1],
+        strides[2],
+        dilations[0],
+        dilations[1],
+        dilations[2],
+    ].emit()
+    module = run_pipeline(
+        ctx.module,
+        Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()),
+    )
+    # CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1 * 7)>
+    # CHECK: #map1 = affine_map<(d0, d1) -> (d0 * 5 + d1 * 8)>
+    # CHECK: #map2 = affine_map<(d0, d1) -> (d0 * 6 + d1 * 9)>
+    # CHECK: module {
+    # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_4_int_5_int_6_int_7_int_8_int_9(%arg0: memref<?x?x?x?x?xf32>, %arg1: memref<?x?x?x?x?xf32>) {
+    # CHECK: %c4 = arith.constant 4 : index
+    # CHECK: %c3 = arith.constant 3 : index
+    # CHECK: %c2 = arith.constant 2 : index
+    # CHECK: %c1 = arith.constant 1 : index
+    # CHECK: %c0 = arith.constant 0 : index
+    # CHECK: %dim = memref.dim %arg0, %c0 : memref<?x?x?x?x?xf32>
+    # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref<?x?x?x?x?xf32>
+    # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref<?x?x?x?x?xf32>
+    # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref<?x?x?x?x?xf32>
+    # CHECK: %dim_3 = memref.dim %arg1, %c4 : memref<?x?x?x?x?xf32>
+    # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2, %dim_3) step (%c1, %c1, %c1, %c1, %c1) {
+    # CHECK: scf.for %arg7 = %c0 to %c1 step %c1 {
+    # CHECK: scf.for %arg8 = %c0 to %c2 step %c1 {
+    # CHECK: scf.for %arg9 = %c0 to %c3 step %c1 {
+    # CHECK: %0 = affine.apply #map(%arg4, %arg7)
+    # CHECK: %1 = affine.apply #map1(%arg5, %arg8)
+    # CHECK: %2 = affine.apply #map2(%arg6, %arg9)
+    # CHECK: %3 = memref.load %arg0[%arg2, %arg3, %0, %1, %2] : memref<?x?x?x?x?xf32>
+    # CHECK: %4 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref<?x?x?x?x?xf32>
+    # CHECK: %5 = arith.maximumf %3, %4 : f32
+    # CHECK: memref.store %5, %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref<?x?x?x?x?xf32>
+    # CHECK: }
+    # CHECK: }
+    # CHECK: }
+    # CHECK: scf.reduce
+    # CHECK: }
+    # CHECK: return
+    # CHECK: }
+    # CHECK: }
+    filecheck_with_comments(module)