From 461a238ae75cc2ff00a5a75c7e087a7ce817b911 Mon Sep 17 00:00:00 2001 From: makslevental Date: Thu, 20 Nov 2025 12:04:46 -0800 Subject: [PATCH 1/2] [eudsl-python-extras] fix more func stuff --- .../mlir/extras/dialects/func.py | 51 ++-- .../mlir/extras/dialects/gpu.py | 5 +- .../tests/dialect/test_func.py | 264 ++---------------- .../tests/dialect/test_linalg.py | 148 +++++----- 4 files changed, 121 insertions(+), 347 deletions(-) diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index 774b724b..357f65c5 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -1,14 +1,13 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import copy import inspect import sys from dataclasses import dataclass from functools import update_wrapper from typing import Optional, List, Union, TypeVar -from ..ast.util import copy_func, copy_object +from ..ast.util import copy_func from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject from ..meta import op_region_builder from .. import types as T @@ -175,7 +174,7 @@ def __init__( self.call_op_ctor = call_op_ctor self.arg_attrs = arg_attrs self.res_attrs = res_attrs - self.generics = copy_object(generics) + self.generics = generics self.loc = loc self.ip = ip self._func_op = None @@ -323,33 +322,24 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): continue if k not in already_reified_type_params: raise RuntimeError( - f"typevar {k} not reified prior to evaluating dependent typevar {tvar}" + f"typevar {k} not reified prior to evaluating dependent typevar {v}" ) cvrs[k] = already_reified_type_params[k] unevaled_type_data = copy_func(unevaled_type_data, cvrs) return unevaled_type_data() - generics = copy_object(self.generics) - for i, tvar in enumerate(generics): - if not isinstance(tvar, TypeVar): - raise RuntimeError( - f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, " - f"you should use a unique one for each generic function." - ) + for i, tvar in enumerate(self.generics): + if tvar.__name__ in body_builder.__globals__: + raise RuntimeError("global typevars for generics are not supported") + type_var_default = None - if sys.version_info >= (3, 12): - type_var = PyTypeVarObject.from_object(tvar) - type_var_bound = type_var.bound - if sys.version_info >= (3, 13) and tvar.has_default(): - type_var_default = type_var.default_value - else: - type_var_bound = tvar.__bound__ + type_var = PyTypeVarObject.from_object(tvar) + type_var_bound = type_var.bound + if sys.version_info >= (3, 13) and tvar.has_default(): + type_var_default = type_var.default_value if bool(type_var_bound): - # before 3.12 typevar was just a python class - # https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966 - if sys.version_info >= (3, 12): - type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound) + type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound) elif not bool(type_var_default): if i >= len(item): raise RuntimeError(f"generic {tvar=} must have concrete val") @@ -372,15 +362,13 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): reified_type_params.append(r) already_reified_type_params[r.name] = r.val - # replace the tvar in body_builder's global context with the reified val - if tvar.__name__ in body_builder.__globals__: - body_builder.__globals__[tvar.__name__] = r.val - # replace the tvar in body_builder's closure with the reified val + # only in the closure if used in the body if r.name in body_builder.__code__.co_freevars: free_i = body_builder.__code__.co_freevars.index(r.name) - assert ( - body_builder.__closure__[free_i].cell_contents == tvar - ), "typevars don't match" + if body_builder.__closure__[free_i].cell_contents != tvar: + raise RuntimeError( + f"typevars don't match: {id(body_builder.__closure__[free_i].cell_contents)=}, {id(tvar)=}" + ) body_builder.__closure__[free_i].cell_contents = r.val name_mangled_generics = [] @@ -419,12 +407,9 @@ def func( func_attrs=None, function_type=None, emit=False, - generics=None, loc=None, ip=None, ) -> FuncBase: - if generics is None and hasattr(f, "__type_params__") and f.__type_params__: - generics = f.__type_params__ func_ = FuncBase( body_builder=f, func_op_ctor=FuncOp.__base__, @@ -436,7 +421,7 @@ def func( res_attrs=res_attrs, func_attrs=func_attrs, function_type=function_type, - generics=generics, + generics=getattr(f, "__type_params__", None), loc=loc, ip=ip, ) diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py b/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py index 8239a7e1..350f296e 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py @@ -439,13 +439,10 @@ def func( res_attrs=None, func_attrs=None, emit=False, - generics=None, loc=None, ip=None, emit_grid=False, ) -> Grid: - if generics is None and hasattr(f, "__type_params__") and f.__type_params__: - generics = f.__type_params__ func_ = GPUFunc( body_builder=f, func_op_ctor=GPUFuncOp, @@ -455,7 +452,7 @@ def func( arg_attrs=arg_attrs, res_attrs=res_attrs, func_attrs=func_attrs, - generics=generics, + generics=getattr(f, "__type_params__", None), loc=loc, ip=ip, ) diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index 261c18e3..768ee05f 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -151,19 +151,15 @@ def foo1(): filecheck_with_comments(mod_ctx.module) -generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"])) - - -@func(generics=list(map(TypeVar, ["M", "N"]))) -def matmul_i32_i32( - A: "T.memref(M, N, T.i32())", - B: "T.memref(M, N, T.i32())", - C: "T.memref(M, N, T.i32())", -): - linalg.matmul(A, B, C) - - def test_func_no_context_2(ctx: MLIRContext): + @func + def matmul_i32_i32[M, N]( + A: "T.memref(M, N, T.i32())", + B: "T.memref(M, N, T.i32())", + C: "T.memref(M, N, T.i32())", + ): + linalg.matmul(A, B, C) + matmul_i32_i32[16, 16].emit() # CHECK: func.func @matmul_i32_i32_int_16_int_16(%[[VAL_0:.*]]: memref<16x16xi32>, %[[VAL_1:.*]]: memref<16x16xi32>, %[[VAL_2:.*]]: memref<16x16xi32>) { @@ -175,8 +171,8 @@ def test_func_no_context_2(ctx: MLIRContext): def test_generics_just_args(ctx: MLIRContext): - @func(generics=generics) - def mat_product_kernel( + @func + def mat_product_kernel[M, K, N, dtype]( A: "T.memref(M, K, dtype)", B: "T.memref(K, N, dtype)", C: "T.memref(M, N, dtype)", @@ -194,10 +190,9 @@ def mat_product_kernel( def test_generics_closure(ctx: MLIRContext): - generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"])) - @func(generics=generics) - def mat_product_kernel( + @func + def mat_product_kernel[M, K, N, dtype]( A: "T.memref(M, K, dtype)", B: "T.memref(K, N, dtype)", C: "T.memref(M, N, dtype)", @@ -222,19 +217,19 @@ def mat_product_kernel( def test_generics_callable(ctx: MLIRContext): _op = TypeVar("_op") - @func(generics=[_op]) - def mat_product_kernel1(): + @func + def mat_product_kernel1[_op](): one = arith.constant(1, T.f32()) two = _op(one, one) - @func(generics=[_op]) - def mat_product_kernel2(): + @func + def mat_product_kernel2[_op](): one = arith.constant(1, T.f32()) two = _op(one, one) - mat_product_kernel1[arith.maximumf].emit() - mat_product_kernel2[arith.minimumf].emit() - mat_product_kernel2[arith.maximumf].emit() + mat_product_kernel1[arith.maximumf,].emit() + mat_product_kernel2[arith.minimumf,].emit() + mat_product_kernel2[arith.maximumf,].emit() # CHECK: func.func @mat_product_kernel1_function_maximumf() { # CHECK: %cst = arith.constant 1.000000e+00 : f32 @@ -255,50 +250,11 @@ def mat_product_kernel2(): filecheck_with_comments(ctx.module) -_op = TypeVar("_op") - - -def test_global_closures(ctx: MLIRContext): - globals()["_op"] = TypeVar("_op") - - @func(generics=[_op]) - def _generic_pool2d_scf(a: T.f32(), b: T.f32()): - _op(a, b) - - _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf] - - @func(generics=[_op]) - def _generic_pool3d_scf( - a: T.f32(), - b: T.f32(), - ): - _op(a, b) - - with pytest.raises( - RuntimeError, - match="0th generic has probably already been reified as ; if you're using a global tvar for the generic," - " you should use a unique one for each generic function.", - ): - _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf] - - _op1 = TypeVar("_op1") - - @func(generics=[_op1]) - def _generic_pool3d_scf( - a: T.f32(), - b: T.f32(), - ): - _op(a, b) - - _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf] - - def test_generics_with_canonicalizations(ctx: MLIRContext): - generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"])) - @func(generics=generics) + @func @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) - def mat_product_kernel( + def mat_product_kernel[M, K, N, dtype]( A: "T.memref(M, K, dtype)", B: "T.memref(K, N, dtype)", C: "T.memref(M, N, dtype)", @@ -434,29 +390,10 @@ def demo_fun1(a, b): def test_name_mangling(ctx: MLIRContext): _S = ShapedType.get_dynamic_size() - generics = ( - kernel_size_0, - kernel_size_1, - stride_0, - stride_1, - dilation_0, - dilation_1, - ) = list( - map( - TypeVar, - [ - "kernel_size_0", - "kernel_size_1", - "stride_0", - "stride_1", - "dilation_0", - "dilation_1", - ], - ) - ) - - @func(generics=generics) - def maxpool2d( + @func + def maxpool2d[ + kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 + ]( input: T.memref(_S, _S, _S, _S, T.f32()), output: T.memref(_S, _S, _S, _S, T.f32()), ): @@ -480,154 +417,3 @@ def maxpool2d( # CHECK: return # CHECK: } filecheck_with_comments(maxpool2d_k) - - -@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") -def test_generics(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - exec( - dedent( - """\ - @func - def mat_product_kernel[ - M, K, N, dtype - ]( - A: "T.memref(M, K, dtype)", - B: "T.memref(K, N, dtype)", - C: "T.memref(M, N, dtype)", - x: T.index(), - y: T.index() - ): - - one = arith.constant(1.0, type=dtype) - tmp = arith.constant(0, type=dtype) - for k, tmp, _ in scf.range_(K, iter_args=[tmp]): - tmp += A[x, k] * B[k, y] - tmp = scf.yield_(tmp) - C[x, y] = tmp + one - - globals()["mat_product_kernel"] = mat_product_kernel - """ - ) - ) - - matk = mat_product_kernel[32, 32, 32, T.f32()].emit() - - # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>, %arg3: index, %arg4: index) { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: %cst_0 = arith.constant 0.000000e+00 : f32 - # CHECK: %c0 = arith.constant 0 : index - # CHECK: %c32 = arith.constant 32 : index - # CHECK: %c1 = arith.constant 1 : index - # CHECK: %0 = scf.for %arg5 = %c0 to %c32 step %c1 iter_args(%arg6 = %cst_0) -> (f32) { - # CHECK: %2 = memref.load %arg0[%arg3, %arg5] : memref<32x32xf32> - # CHECK: %3 = memref.load %arg1[%arg5, %arg4] : memref<32x32xf32> - # CHECK: %4 = arith.mulf %2, %3 : f32 - # CHECK: %5 = arith.addf %arg6, %4 : f32 - # CHECK: scf.yield %5 : f32 - # CHECK: } - # CHECK: %1 = arith.addf %0, %cst : f32 - # CHECK: memref.store %1, %arg2[%arg3, %arg4] : memref<32x32xf32> - # CHECK: return - # CHECK: } - filecheck_with_comments(matk) - - -@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") -def test_generic_type_var_closure_patching(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - exec( - dedent( - """\ - from mlir.extras.ast.py_type import PyTypeVarObject - - def fun2[foo, bar, A: foo + bar](): - print(A.__bound__) - - - A_type_param = fun2.__type_params__[2] - - a = PyTypeVarObject.from_object(A_type_param) - a_something = a.bound.contents.into_object() - a_something.__closure__[0].cell_contents = 5 - a_something.__closure__[1].cell_contents = 7 - - fun2() - """ - ) - ) - - -@pytest.mark.skipif( - sys.version_info < (3, 13) or platform.system() == "Windows", - reason="requires python3.13 or higher (and windows can't find the source file)", -) -def test_generic_type_var_closure_patching_dependent_generics(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - # but also need a real file here because rewriter needs source... - src = dedent( - """\ - from mlir.extras.dialects import arith, scf - from mlir.extras.ast.canonicalize import canonicalize - import mlir.extras.types as T - - @func - def test_plain[ - M, - K, - N, - dtype, - A_t = T.memref(M, K, dtype), - B_t = T.memref(K, N, dtype), - C_t = T.memref(M, N, dtype), - ](A: A_t, B: B_t, C: C_t): - one = arith.constant(1.0, type=dtype) - - @func - @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) - def test_2_with_rewrite[ - M, - K, - N, - dtype, - A_t = T.memref(M, K, dtype), - B_t = T.memref(K, N, dtype), - C_t = T.memref(M, N, dtype), - ](A: A_t, B: B_t, C: C_t): - one = arith.constant(1.0, type=dtype) - - globals()["test_plain"] = test_plain - globals()["test_2_with_rewrite"] = test_2_with_rewrite - """ - ) - - with tempfile.NamedTemporaryFile(mode="w") as tmp: - tmp.write(src) - tmp.flush() - code = compile(src, tmp.name, "exec") - exec(code, globals(), locals()) - - test_plain[1, 2, 3, T.f32()].emit() - test_2_with_rewrite[1, 2, 3, T.f32()].emit() - - test_plain[4, 5, 6, T.f16()].emit() - test_2_with_rewrite[4, 5, 6, T.f16()].emit() - - # CHECK: func.func @"test_plain_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: return - # CHECK: } - # CHECK: func.func @"test_2_with_rewrite_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: return - # CHECK: } - # CHECK: func.func @"test_plain_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) { - # CHECK: %cst = arith.constant 1.000000e+00 : f16 - # CHECK: return - # CHECK: } - # CHECK: func.func @"test_2_with_rewrite_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) { - # CHECK: %cst = arith.constant 1.000000e+00 : f16 - # CHECK: return - # CHECK: } - ctx.module.operation.verify() - filecheck_with_comments(ctx.module) diff --git a/projects/eudsl-python-extras/tests/dialect/test_linalg.py b/projects/eudsl-python-extras/tests/dialect/test_linalg.py index d76fa93e..9b20aa5b 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_linalg.py +++ b/projects/eudsl-python-extras/tests/dialect/test_linalg.py @@ -1,7 +1,6 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import TypeVar import mlir.extras.types as T import pytest @@ -51,10 +50,76 @@ def test_np_constructor(ctx: MLIRContext): filecheck_with_comments(ctx.module) +def test_pooling_nchw_max(ctx: MLIRContext): + S = ShapedType.get_dynamic_size() + + @func + def maxpool2d[ + kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 + ]( + input: T.memref(S, S, S, S, T.f32()), + output: T.memref(S, S, S, S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1), + T.f32(), + ) + + linalg.pooling_nchw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1], + dilations=[dilation_0, dilation_1], + ) + + kernel_sizes = [2, 3] + strides = [4, 5] + dilations = [6, 7] + maxpool2d_k = maxpool2d[ + kernel_sizes[0], + kernel_sizes[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + ].emit() + module = run_pipeline( + ctx.module, + Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), + ) + # CHECK: func.func @maxpool2d_int_2_int_3_int_4_int_5_int_6_int_7(%arg0: memref, %arg1: memref) { + # CHECK: %c3 = arith.constant 3 : index + # CHECK: %c2 = arith.constant 2 : index + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c0 = arith.constant 0 : index + # CHECK: %dim = memref.dim %arg0, %c0 : memref + # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref + # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref + # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref + # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5) = (%c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2) step (%c1, %c1, %c1, %c1) { + # CHECK: scf.for %arg6 = %c0 to %c2 step %c1 { + # CHECK: scf.for %arg7 = %c0 to %c3 step %c1 { + # CHECK: %0 = affine.apply #map(%arg4, %arg6) + # CHECK: %1 = affine.apply #map1(%arg5, %arg7) + # CHECK: %2 = memref.load %arg0[%arg2, %arg3, %0, %1] : memref + # CHECK: %3 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5] : memref + # CHECK: %4 = arith.maximumf %3, %2 : f32 + # CHECK: memref.store %4, %arg1[%arg2, %arg3, %arg4, %arg5] : memref + # CHECK: } + # CHECK: } + # CHECK: scf.reduce + # CHECK: } + # CHECK: return + # CHECK: } + filecheck_with_comments(module) + + def test_pooling_ncdhw_max(ctx: MLIRContext): S = ShapedType.get_dynamic_size() - generics = ( + @func + def maxpool3d[ kernel_size_0, kernel_size_1, kernel_size_2, @@ -64,37 +129,7 @@ def test_pooling_ncdhw_max(ctx: MLIRContext): dilation_0, dilation_1, dilation_2, - ) = list( - map( - TypeVar, - [ - "kernel_size_0", - "kernel_size_1", - "kernel_size_2", - "stride_0", - "stride_1", - "stride_2", - "dilation_0", - "dilation_1", - "dilation_2", - ], - ) - ) - - @func( - generics=( - kernel_size_0, - kernel_size_1, - kernel_size_2, - stride_0, - stride_1, - stride_2, - dilation_0, - dilation_1, - dilation_2, - ) - ) - def maxpool3d( + ]( input: T.memref(S, S, S, S, S, T.f32()), output: T.memref(S, S, S, S, S, T.f32()), ): @@ -111,9 +146,9 @@ def maxpool3d( dilations=[dilation_0, dilation_1, dilation_2], ) - kernel_sizes = [1, 2, 2] - strides = [1, 1, 1] - dilations = [1, 1, 1] + kernel_sizes = [1, 2, 3] + strides = [5, 6, 7] + dilations = [7, 8, 9] maxpool3d_k = maxpool3d[ kernel_sizes[0], kernel_sizes[1], @@ -125,9 +160,9 @@ def maxpool3d( dilations[1], dilations[2], ].emit() - # CHECK: func.func @maxpool3d_int_1_int_2_int_2_int_1_int_1_int_1_int_1_int_1_int_1(%arg0: memref, %arg1: memref) { - # CHECK: %alloca = memref.alloca() : memref<1x2x2xf32> - # CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 + d5, d3 + d6, d4 + d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %alloca : memref, memref<1x2x2xf32>) outs(%arg1 : memref) { + # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_5_int_6_int_7_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { + # CHECK: %alloca = memref.alloca() : memref<1x2x3xf32> + # CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 5 + d5 * 7, d3 * 6 + d6 * 8, d4 * 7 + d7 * 9)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %alloca : memref, memref<1x2x3xf32>) outs(%arg1 : memref) { # CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32): # CHECK: %0 = arith.maximumf %in, %out : f32 # CHECK: linalg.yield %0 : f32 @@ -140,7 +175,8 @@ def maxpool3d( def test_pooling_ncdhw_max_parallel(ctx: MLIRContext): S = ShapedType.get_dynamic_size() - generics = ( + @func + def maxpool3d[ kernel_size_0, kernel_size_1, kernel_size_2, @@ -150,37 +186,7 @@ def test_pooling_ncdhw_max_parallel(ctx: MLIRContext): dilation_0, dilation_1, dilation_2, - ) = list( - map( - TypeVar, - [ - "kernel_size_0", - "kernel_size_1", - "kernel_size_2", - "stride_0", - "stride_1", - "stride_2", - "dilation_0", - "dilation_1", - "dilation_2", - ], - ) - ) - - @func( - generics=( - kernel_size_0, - kernel_size_1, - kernel_size_2, - stride_0, - stride_1, - stride_2, - dilation_0, - dilation_1, - dilation_2, - ) - ) - def maxpool3d( + ]( input: T.memref(S, S, S, S, S, T.f32()), output: T.memref(S, S, S, S, S, T.f32()), ): From dc3b71a5010db0c7b45fd500de1bc4beba651c95 Mon Sep 17 00:00:00 2001 From: makslevental Date: Fri, 21 Nov 2025 11:17:24 -0800 Subject: [PATCH 2/2] [eudsl-python-extras] fix more func stuff --- ...build_test_release_eudsl_python_extras.yml | 10 +- .../mlir/extras/ast/util.py | 1 + .../mlir/extras/dialects/func.py | 5 +- .../mlir/extras/dialects/gpu.py | 5 +- .../tests/dialect/test_func.py | 239 +------- .../tests/dialect/test_generics.py | 573 ++++++++++++++++++ .../tests/dialect/test_gpu.py | 207 +------ .../tests/dialect/test_linalg.py | 212 ------- 8 files changed, 602 insertions(+), 650 deletions(-) create mode 100644 projects/eudsl-python-extras/tests/dialect/test_generics.py diff --git a/.github/workflows/build_test_release_eudsl_python_extras.yml b/.github/workflows/build_test_release_eudsl_python_extras.yml index b989b8d3..b8c4795e 100644 --- a/.github/workflows/build_test_release_eudsl_python_extras.yml +++ b/.github/workflows/build_test_release_eudsl_python_extras.yml @@ -64,6 +64,7 @@ jobs: - name: "Build eudsl-python-extras sdist" run: | + SHA_SHORT="$(git rev-parse --short HEAD)" WHEEL_VERSION="$(date +'%Y%m%d.%H%M')+$SHA_SHORT" pushd projects/eudsl-python-extras @@ -175,7 +176,14 @@ jobs: run: python -m pip install dist/eudsl_python_extras*.tar.gz - name: "Test eudsl-python-extras" - run: python -m pytest projects/eudsl-python-extras/tests + run: | + + IGNORE="" + if [[ $(python -c "print(__import__('sys').version_info < (3, 13))") == "True" ]]; then + IGNORE="--ignore projects/eudsl-python-extras/tests/dialect/test_generics.py" + fi + + python -m pytest projects/eudsl-python-extras/tests $IGNORE release-eudsl-python-extras: diff --git a/projects/eudsl-python-extras/mlir/extras/ast/util.py b/projects/eudsl-python-extras/mlir/extras/ast/util.py index 32f5989c..c5f29267 100644 --- a/projects/eudsl-python-extras/mlir/extras/ast/util.py +++ b/projects/eudsl-python-extras/mlir/extras/ast/util.py @@ -12,6 +12,7 @@ from bytecode import ConcreteBytecode from cloudpickle import cloudpickle + from ...ir import Type diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index 357f65c5..c29043d3 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -7,10 +7,10 @@ from functools import update_wrapper from typing import Optional, List, Union, TypeVar -from ..ast.util import copy_func +from .. import types as T from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject +from ..ast.util import copy_func from ..meta import op_region_builder -from .. import types as T from ..util import get_user_code_loc, make_maybe_no_args_decorator from ...dialects._ods_common import get_op_result_or_op_results from ...dialects.func import * @@ -26,7 +26,6 @@ Value, ) - _call = call diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py b/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py index 350f296e..7dfda437 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/gpu.py @@ -5,7 +5,6 @@ from functools import partial from typing import Any, List, Optional, Tuple, Union - from .func import FuncBase from .. import types as T from ..meta import ( @@ -25,6 +24,10 @@ get_op_result_or_op_results, ) from ...dialects.gpu import * + +del constant +# constant needs to be below gpu import because it needs to shadow upstream's arith.constant +# noinspection PyUnusedImports from .arith import constant from ...ir import ( ArrayAttr, diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index 768ee05f..7d9408aa 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -2,19 +2,14 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception import inspect -import platform import sys -import tempfile -from textwrap import dedent -from typing import TypeVar -import pytest import mlir.extras.types as T -from mlir.extras.ast.canonicalize import canonicalize +import pytest +from mlir.ir import FunctionType + from mlir.extras.context import mlir_mod_ctx, RAIIMLIRContextModule -from mlir.extras.dialects import linalg, arith, scf, memref from mlir.extras.dialects.arith import constant -from mlir.ir import FunctionType, Module, ShapedType from mlir.extras.dialects.func import func # noinspection PyUnresolvedReferences @@ -151,202 +146,6 @@ def foo1(): filecheck_with_comments(mod_ctx.module) -def test_func_no_context_2(ctx: MLIRContext): - @func - def matmul_i32_i32[M, N]( - A: "T.memref(M, N, T.i32())", - B: "T.memref(M, N, T.i32())", - C: "T.memref(M, N, T.i32())", - ): - linalg.matmul(A, B, C) - - matmul_i32_i32[16, 16].emit() - - # CHECK: func.func @matmul_i32_i32_int_16_int_16(%[[VAL_0:.*]]: memref<16x16xi32>, %[[VAL_1:.*]]: memref<16x16xi32>, %[[VAL_2:.*]]: memref<16x16xi32>) { - # CHECK: linalg.matmul {cast = #linalg.type_fn} ins(%[[VAL_0]], %[[VAL_1]] : memref<16x16xi32>, memref<16x16xi32>) outs(%[[VAL_2]] : memref<16x16xi32>) - # CHECK: return - # CHECK: } - - filecheck_with_comments(ctx.module) - - -def test_generics_just_args(ctx: MLIRContext): - @func - def mat_product_kernel[M, K, N, dtype]( - A: "T.memref(M, K, dtype)", - B: "T.memref(K, N, dtype)", - C: "T.memref(M, N, dtype)", - ): - one = arith.constant(1.0, dtype) - - mat_product_kernel[32, 32, 32, T.f32()].emit() - - # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) { - # CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 - # CHECK: return - # CHECK: } - - filecheck_with_comments(ctx.module) - - -def test_generics_closure(ctx: MLIRContext): - - @func - def mat_product_kernel[M, K, N, dtype]( - A: "T.memref(M, K, dtype)", - B: "T.memref(K, N, dtype)", - C: "T.memref(M, N, dtype)", - ): - one = arith.constant(1, dtype) - - mat_product_kernel[32, 32, 32, T.i32()].emit() - mat_product_kernel[32, 32, 32, T.f32()].emit() - - # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_i32(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) { - # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 - # CHECK: return - # CHECK: } - # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: return - # CHECK: } - - filecheck_with_comments(ctx.module) - - -def test_generics_callable(ctx: MLIRContext): - _op = TypeVar("_op") - - @func - def mat_product_kernel1[_op](): - one = arith.constant(1, T.f32()) - two = _op(one, one) - - @func - def mat_product_kernel2[_op](): - one = arith.constant(1, T.f32()) - two = _op(one, one) - - mat_product_kernel1[arith.maximumf,].emit() - mat_product_kernel2[arith.minimumf,].emit() - mat_product_kernel2[arith.maximumf,].emit() - - # CHECK: func.func @mat_product_kernel1_function_maximumf() { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: %0 = arith.maximumf %cst, %cst : f32 - # CHECK: return - # CHECK: } - # CHECK: func.func @mat_product_kernel2_function_minimumf() { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: %0 = arith.minimumf %cst, %cst : f32 - # CHECK: return - # CHECK: } - # CHECK: func.func @mat_product_kernel2_function_maximumf() { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: %0 = arith.maximumf %cst, %cst : f32 - # CHECK: return - # CHECK: } - - filecheck_with_comments(ctx.module) - - -def test_generics_with_canonicalizations(ctx: MLIRContext): - - @func - @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) - def mat_product_kernel[M, K, N, dtype]( - A: "T.memref(M, K, dtype)", - B: "T.memref(K, N, dtype)", - C: "T.memref(M, N, dtype)", - ): - x = arith.constant(1, index=True) - y = arith.constant(1, index=True) - one = arith.constant(1.0, type=dtype) - tmp = arith.constant(0, type=dtype) - for k, tmp, _ in scf.range_(K, iter_args=[tmp]): - tmp += A[x, k] * B[k, y] - tmp = yield tmp - C[x, y] = tmp + one - - mat_product_kernel[32, 32, 32, T.f32()].emit() - - # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) { - # CHECK: %[[VAL_3:.*]] = arith.constant 1 : index - # CHECK: %[[VAL_4:.*]] = arith.constant 1 : index - # CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32 - # CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 - # CHECK: %[[VAL_7:.*]] = arith.constant 0 : index - # CHECK: %[[VAL_8:.*]] = arith.constant 32 : index - # CHECK: %[[VAL_9:.*]] = arith.constant 1 : index - # CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (f32) { - # CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_11]]] : memref<32x32xf32> - # CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_11]], %[[VAL_4]]] : memref<32x32xf32> - # CHECK: %[[VAL_15:.*]] = math.fma %[[VAL_13]], %[[VAL_14]], %[[VAL_12]] : f32 - # CHECK: scf.yield %[[VAL_15]] : f32 - # CHECK: } - # CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_17:.*]], %[[VAL_5]] : f32 - # CHECK: memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_4]]] : memref<32x32xf32> - # CHECK: return - # CHECK: } - - filecheck_with_comments(ctx.module) - - -@pytest.mark.skipif( - sys.version_info < (3, 13) or platform.system() == "Windows", - reason="requires python3.13 or higher (and windows can't find the source file)", -) -def test_generics_assignment(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - # but also need a real file here because rewriter needs source... - - src = dedent( - """\ - @func - def type_bound[M, K, N: T.i32()]( - A: "T.memref(M, K, T.f32())", - B: "T.memref(K, N, T.f32())", - C: "T.memref(M, N, T.f32())", - ): - x = arith.constant(1, index=True) - y = arith.constant(1, index=True) - - @func - def type_bound_and_default[M, K, N: T.i32() = 10, L: T.f32() = 10.0]( - A: "T.memref(M, K, T.f32())", - B: "T.memref(K, N, T.f32())", - C: "T.memref(M, N, T.f32())", - ): - x = arith.constant(1, index=True) - y = arith.constant(1, index=True) - n = arith.constant(L) - - globals()["type_bound"] = type_bound - globals()["type_bound_and_default"] = type_bound_and_default - """ - ) - - with tempfile.NamedTemporaryFile(mode="w") as tmp: - tmp.write(src) - tmp.flush() - code = compile(src, tmp.name, "exec") - exec(code, globals(), locals()) - - # CHECK: func.func @type_bound_int_32_int_32_i32_10(%arg0: memref<32x32xf32>, %arg1: memref<32x10xf32>, %arg2: memref<32x10xf32>) { - # CHECK: %c1 = arith.constant 1 : index - # CHECK: %c1_0 = arith.constant 1 : index - # CHECK: return - # CHECK: } - type_bound[32, 32, 10].emit() - # CHECK: func.func @type_bound_and_default_int_32_int_32_i32_10_f32_10.0(%arg0: memref<32x32xf32>, %arg1: memref<32x10xf32>, %arg2: memref<32x10xf32>) { - # CHECK: %c1 = arith.constant 1 : index - # CHECK: %c1_0 = arith.constant 1 : index - # CHECK: %cst = arith.constant 1.000000e+01 : f32 - # CHECK: return - # CHECK: } - type_bound_and_default[32, 32].emit() - - def test_raii_mlir_context_module(): ctx = RAIIMLIRContextModule() @@ -385,35 +184,3 @@ def demo_fun1(a, b): # CHECK: } filecheck_with_comments(ctx.module) - - -def test_name_mangling(ctx: MLIRContext): - _S = ShapedType.get_dynamic_size() - - @func - def maxpool2d[ - kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 - ]( - input: T.memref(_S, _S, _S, _S, T.f32()), - output: T.memref(_S, _S, _S, _S, T.f32()), - ): - kernel_shape_surrogate = memref.alloca( - (kernel_size_0, kernel_size_1), - T.f32(), - ) - - linalg.pooling_nchw_max( - input, - kernel_shape_surrogate, - output, - strides=[stride_0, stride_1], - dilations=[dilation_0, dilation_1], - ) - - maxpool2d_k = maxpool2d[2, 2, 1, 1, 1, 1].emit() - # CHECK: func.func @maxpool2d_int_2_int_2_int_1_int_1_int_1_int_1(%arg0: memref, %arg1: memref) { - # CHECK: %alloca = memref.alloca() : memref<2x2xf32> - # CHECK: linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %alloca : memref, memref<2x2xf32>) outs(%arg1 : memref) - # CHECK: return - # CHECK: } - filecheck_with_comments(maxpool2d_k) diff --git a/projects/eudsl-python-extras/tests/dialect/test_generics.py b/projects/eudsl-python-extras/tests/dialect/test_generics.py new file mode 100644 index 00000000..142343ff --- /dev/null +++ b/projects/eudsl-python-extras/tests/dialect/test_generics.py @@ -0,0 +1,573 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import sys +from textwrap import dedent +from typing import TypeVar + +import mlir.extras.types as T +import pytest +from mlir.ir import ShapedType + +from mlir.extras.ast.canonicalize import canonicalize +from mlir.extras.ast.py_type import PyTypeVarObject +from mlir.extras.dialects import linalg, arith, scf, memref, gpu +from mlir.extras.dialects.func import func +from mlir.extras.dialects.gpu import ( + set_container_module, + module, +) +from mlir.extras.runtime.passes import run_pipeline, Pipeline + +# noinspection PyUnresolvedReferences +from mlir.extras.testing import ( + mlir_ctx as ctx, + filecheck, + filecheck_with_comments, + MLIRContext, +) + +# needed since the fix isn't defined here nor conftest.py +pytest.mark.usefixtures("ctx") + + +def test_func_no_context_2(ctx: MLIRContext): + @func + def matmul_i32_i32[M, N]( + A: "T.memref(M, N, T.i32())", + B: "T.memref(M, N, T.i32())", + C: "T.memref(M, N, T.i32())", + ): + linalg.matmul(A, B, C) + + matmul_i32_i32[16, 16].emit() + + # CHECK: func.func @matmul_i32_i32_int_16_int_16(%[[VAL_0:.*]]: memref<16x16xi32>, %[[VAL_1:.*]]: memref<16x16xi32>, %[[VAL_2:.*]]: memref<16x16xi32>) { + # CHECK: linalg.matmul {cast = #linalg.type_fn} ins(%[[VAL_0]], %[[VAL_1]] : memref<16x16xi32>, memref<16x16xi32>) outs(%[[VAL_2]] : memref<16x16xi32>) + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_just_args(ctx: MLIRContext): + @func + def mat_product_kernel[M, K, N, dtype]( + A: "T.memref(M, K, dtype)", + B: "T.memref(K, N, dtype)", + C: "T.memref(M, N, dtype)", + ): + one = arith.constant(1.0, dtype) + + mat_product_kernel[32, 32, 32, T.f32()].emit() + + # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) { + # CHECK: %[[VAL_3:.*]] = arith.constant 1.000000e+00 : f32 + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_closure(ctx: MLIRContext): + @func + def mat_product_kernel[M, K, N, dtype]( + A: "T.memref(M, K, dtype)", + B: "T.memref(K, N, dtype)", + C: "T.memref(M, N, dtype)", + ): + one = arith.constant(1, dtype) + + mat_product_kernel[32, 32, 32, T.i32()].emit() + mat_product_kernel[32, 32, 32, T.f32()].emit() + + # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_i32(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) { + # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 + # CHECK: return + # CHECK: } + # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_callable(ctx: MLIRContext): + _op = TypeVar("_op") + + @func + def mat_product_kernel1[_op](): + one = arith.constant(1, T.f32()) + two = _op(one, one) + + @func + def mat_product_kernel2[_op](): + one = arith.constant(1, T.f32()) + two = _op(one, one) + + mat_product_kernel1[arith.maximumf,].emit() + mat_product_kernel2[arith.minimumf,].emit() + mat_product_kernel2[arith.maximumf,].emit() + + # CHECK: func.func @mat_product_kernel1_function_maximumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.maximumf %cst, %cst : f32 + # CHECK: return + # CHECK: } + # CHECK: func.func @mat_product_kernel2_function_minimumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.minimumf %cst, %cst : f32 + # CHECK: return + # CHECK: } + # CHECK: func.func @mat_product_kernel2_function_maximumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.maximumf %cst, %cst : f32 + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_with_canonicalizations(ctx: MLIRContext): + @func + @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) + def mat_product_kernel[M, K, N, dtype]( + A: "T.memref(M, K, dtype)", + B: "T.memref(K, N, dtype)", + C: "T.memref(M, N, dtype)", + ): + x = arith.constant(1, index=True) + y = arith.constant(1, index=True) + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + for k, tmp, _ in scf.range_(K, iter_args=[tmp]): + tmp += A[x, k] * B[k, y] + tmp = yield tmp + C[x, y] = tmp + one + + mat_product_kernel[32, 32, 32, T.f32()].emit() + + # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%[[VAL_0:.*]]: memref<32x32xf32>, %[[VAL_1:.*]]: memref<32x32xf32>, %[[VAL_2:.*]]: memref<32x32xf32>) { + # CHECK: %[[VAL_3:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_4:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_5:.*]] = arith.constant 1.000000e+00 : f32 + # CHECK: %[[VAL_6:.*]] = arith.constant 0.000000e+00 : f32 + # CHECK: %[[VAL_7:.*]] = arith.constant 0 : index + # CHECK: %[[VAL_8:.*]] = arith.constant 32 : index + # CHECK: %[[VAL_9:.*]] = arith.constant 1 : index + # CHECK: %[[VAL_10:.*]] = scf.for %[[VAL_11:.*]] = %[[VAL_7]] to %[[VAL_8]] step %[[VAL_9]] iter_args(%[[VAL_12:.*]] = %[[VAL_6]]) -> (f32) { + # CHECK: %[[VAL_13:.*]] = memref.load %[[VAL_0]]{{\[}}%[[VAL_3]], %[[VAL_11]]] : memref<32x32xf32> + # CHECK: %[[VAL_14:.*]] = memref.load %[[VAL_1]]{{\[}}%[[VAL_11]], %[[VAL_4]]] : memref<32x32xf32> + # CHECK: %[[VAL_15:.*]] = math.fma %[[VAL_13]], %[[VAL_14]], %[[VAL_12]] : f32 + # CHECK: scf.yield %[[VAL_15]] : f32 + # CHECK: } + # CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_17:.*]], %[[VAL_5]] : f32 + # CHECK: memref.store %[[VAL_16]], %[[VAL_2]]{{\[}}%[[VAL_3]], %[[VAL_4]]] : memref<32x32xf32> + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_assignment(ctx: MLIRContext): + @func + def type_bound[M, K, N: T.i32()]( + A: "T.memref(M, K, T.f32())", + B: "T.memref(K, N, T.f32())", + C: "T.memref(M, N, T.f32())", + ): + x = arith.constant(1, index=True) + y = arith.constant(1, index=True) + + @func + def type_bound_and_default[M, K, N: T.i32() = 10, L: T.f32() = 10.0]( + A: "T.memref(M, K, T.f32())", + B: "T.memref(K, N, T.f32())", + C: "T.memref(M, N, T.f32())", + ): + x = arith.constant(1, index=True) + y = arith.constant(1, index=True) + n = arith.constant(L) + + # CHECK: func.func @type_bound_int_32_int_32_i32_10(%arg0: memref<32x32xf32>, %arg1: memref<32x10xf32>, %arg2: memref<32x10xf32>) { + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c1_0 = arith.constant 1 : index + # CHECK: return + # CHECK: } + type_bound[32, 32, 10].emit() + # CHECK: func.func @type_bound_and_default_int_32_int_32_i32_10_f32_10.0(%arg0: memref<32x32xf32>, %arg1: memref<32x10xf32>, %arg2: memref<32x10xf32>) { + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c1_0 = arith.constant 1 : index + # CHECK: %cst = arith.constant 1.000000e+01 : f32 + # CHECK: return + # CHECK: } + type_bound_and_default[32, 32].emit() + + +def test_name_mangling(ctx: MLIRContext): + _S = ShapedType.get_dynamic_size() + + @func + def maxpool2d[ + kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 + ]( + input: T.memref(_S, _S, _S, _S, T.f32()), + output: T.memref(_S, _S, _S, _S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1), + T.f32(), + ) + + linalg.pooling_nchw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1], + dilations=[dilation_0, dilation_1], + ) + + maxpool2d_k = maxpool2d[2, 2, 1, 1, 1, 1].emit() + # CHECK: func.func @maxpool2d_int_2_int_2_int_1_int_1_int_1_int_1(%arg0: memref, %arg1: memref) { + # CHECK: %alloca = memref.alloca() : memref<2x2xf32> + # CHECK: linalg.pooling_nchw_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %alloca : memref, memref<2x2xf32>) outs(%arg1 : memref) + # CHECK: return + # CHECK: } + filecheck_with_comments(maxpool2d_k) + + +@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") +def test_generics(ctx: MLIRContext): + set_container_module(ctx.module) + + @gpu.func + def mat_product_kernel[M, K, N, dtype]( + A: "T.memref(M, K, dtype)", + B: "T.memref(K, N, dtype)", + C: "T.memref(M, N, dtype)", + ): + x = gpu.block_dim.x * gpu.block_idx.x + gpu.thread_idx.x + y = gpu.block_dim.y * gpu.block_idx.y + gpu.thread_idx.y + + one = arith.constant(1.0, type=dtype) + tmp = arith.constant(0, type=dtype) + for k, tmp, _ in scf.range_(K, iter_args=[tmp]): + tmp += A[x, k] * B[k, y] + tmp = scf.yield_(tmp) + C[x, y] = tmp + one + + @module("naive", ["#nvvm.target"]) + def _(): + mat_product_kernel[32, 32, 32, T.f32()].emit() # noqa: F821 + + correct = dedent( + """\ + module attributes {gpu.container_module} { + gpu.module @naive [#nvvm.target] { + gpu.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) kernel { + %block_dim_x = gpu.block_dim x + %block_id_x = gpu.block_id x + %0 = arith.muli %block_dim_x, %block_id_x : index + %thread_id_x = gpu.thread_id x + %1 = arith.addi %0, %thread_id_x : index + %block_dim_y = gpu.block_dim y + %block_id_y = gpu.block_id y + %2 = arith.muli %block_dim_y, %block_id_y : index + %thread_id_y = gpu.thread_id y + %3 = arith.addi %2, %thread_id_y : index + %cst = arith.constant 1.000000e+00 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %4 = scf.for %arg3 = %c0 to %c32 step %c1 iter_args(%arg4 = %cst_0) -> (f32) { + %6 = memref.load %arg0[%1, %arg3] : memref<32x32xf32> + %7 = memref.load %arg1[%arg3, %3] : memref<32x32xf32> + %8 = arith.mulf %6, %7 : f32 + %9 = arith.addf %arg4, %8 : f32 + scf.yield %9 : f32 + } + %5 = arith.addf %4, %cst : f32 + memref.store %5, %arg2[%1, %3] : memref<32x32xf32> + gpu.return + } + } + } + """ + ) + + filecheck(correct, ctx.module) + + +def test_generic_type_var_closure_patching(ctx: MLIRContext): + def fun2[foo, bar, A: foo + bar](): + print(A.__bound__) + + A_type_param = fun2.__type_params__[2] + + a = PyTypeVarObject.from_object(A_type_param) + a_something = a.bound.contents.into_object() + a_something.__closure__[0].cell_contents = 5 + a_something.__closure__[1].cell_contents = 7 + + fun2() + + +def test_generic_type_var_closure_patching_dependent_generics(ctx: MLIRContext): + @gpu.func + def test_plain[M, K, N, dtype, A_t = T.memref( + M, K, dtype + ), B_t = T.memref(K, N, dtype), C_t = T.memref(M, N, dtype),]( + A: A_t, B: B_t, C: C_t + ): + one = arith.constant(1.0, type=dtype) + + @gpu.func + @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) + def test_2_with_rewrite[M, K, N, dtype, A_t = T.memref( + M, K, dtype + ), B_t = T.memref(K, N, dtype), C_t = T.memref(M, N, dtype),]( + A: A_t, B: B_t, C: C_t + ): + one = arith.constant(1.0, type=dtype) + + @module("mod1", ["#nvvm.target"]) + def _(): + test_plain[1, 2, 3, T.f32()].emit() # noqa: F821 + test_2_with_rewrite[1, 2, 3, T.f32()].emit() # noqa: F821 + + @module("mod2", ["#nvvm.target"]) + def _(): + test_plain[4, 5, 6, T.f16()].emit() # noqa: F821 + test_2_with_rewrite[4, 5, 6, T.f16()].emit() # noqa: F821 + + # CHECK: gpu.module @mod1 [#nvvm.target] { + # CHECK: gpu.func @"test_plain_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) kernel { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @"test_2_with_rewrite_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) kernel { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: gpu.return + # CHECK: } + # CHECK: } + # CHECK: gpu.module @mod2 [#nvvm.target] { + # CHECK: gpu.func @"test_plain_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) kernel { + # CHECK: %cst = arith.constant 1.000000e+00 : f16 + # CHECK: gpu.return + # CHECK: } + # CHECK: gpu.func @"test_2_with_rewrite_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) kernel { + # CHECK: %cst = arith.constant 1.000000e+00 : f16 + # CHECK: gpu.return + # CHECK: } + # CHECK: } + filecheck_with_comments(ctx.module) + + +def test_pooling_nchw_max(ctx: MLIRContext): + S = ShapedType.get_dynamic_size() + + @func + def maxpool2d[ + kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 + ]( + input: T.memref(S, S, S, S, T.f32()), + output: T.memref(S, S, S, S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1), + T.f32(), + ) + + linalg.pooling_nchw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1], + dilations=[dilation_0, dilation_1], + ) + + kernel_sizes = [2, 3] + strides = [4, 5] + dilations = [6, 7] + maxpool2d_k = maxpool2d[ + kernel_sizes[0], + kernel_sizes[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + ].emit() + module = run_pipeline( + ctx.module, + Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), + ) + # CHECK: func.func @maxpool2d_int_2_int_3_int_4_int_5_int_6_int_7(%arg0: memref, %arg1: memref) { + # CHECK: %c3 = arith.constant 3 : index + # CHECK: %c2 = arith.constant 2 : index + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c0 = arith.constant 0 : index + # CHECK: %dim = memref.dim %arg0, %c0 : memref + # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref + # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref + # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref + # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5) = (%c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2) step (%c1, %c1, %c1, %c1) { + # CHECK: scf.for %arg6 = %c0 to %c2 step %c1 { + # CHECK: scf.for %arg7 = %c0 to %c3 step %c1 { + # CHECK: %0 = affine.apply #map(%arg4, %arg6) + # CHECK: %1 = affine.apply #map1(%arg5, %arg7) + # CHECK: %2 = memref.load %arg0[%arg2, %arg3, %0, %1] : memref + # CHECK: %3 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5] : memref + # CHECK: %4 = arith.maximumf %3, %2 : f32 + # CHECK: memref.store %4, %arg1[%arg2, %arg3, %arg4, %arg5] : memref + # CHECK: } + # CHECK: } + # CHECK: scf.reduce + # CHECK: } + # CHECK: return + # CHECK: } + filecheck_with_comments(module) + + +def test_pooling_ncdhw_max(ctx: MLIRContext): + S = ShapedType.get_dynamic_size() + + @func + def maxpool3d[ + kernel_size_0, + kernel_size_1, + kernel_size_2, + stride_0, + stride_1, + stride_2, + dilation_0, + dilation_1, + dilation_2, + ]( + input: T.memref(S, S, S, S, S, T.f32()), + output: T.memref(S, S, S, S, S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1, kernel_size_2), + T.f32(), + ) + + linalg.pooling_ncdhw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1, stride_2], + dilations=[dilation_0, dilation_1, dilation_2], + ) + + kernel_sizes = [1, 2, 3] + strides = [5, 6, 7] + dilations = [7, 8, 9] + maxpool3d_k = maxpool3d[ + kernel_sizes[0], + kernel_sizes[1], + kernel_sizes[2], + strides[0], + strides[1], + strides[2], + dilations[0], + dilations[1], + dilations[2], + ].emit() + # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_5_int_6_int_7_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { + # CHECK: %alloca = memref.alloca() : memref<1x2x3xf32> + # CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 5 + d5 * 7, d3 * 6 + d6 * 8, d4 * 7 + d7 * 9)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %alloca : memref, memref<1x2x3xf32>) outs(%arg1 : memref) { + # CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32): + # CHECK: %0 = arith.maximumf %in, %out : f32 + # CHECK: linalg.yield %0 : f32 + # CHECK: } + # CHECK: return + # CHECK: } + filecheck_with_comments(maxpool3d_k) + + +def test_pooling_ncdhw_max_parallel(ctx: MLIRContext): + S = ShapedType.get_dynamic_size() + + @func + def maxpool3d[ + kernel_size_0, + kernel_size_1, + kernel_size_2, + stride_0, + stride_1, + stride_2, + dilation_0, + dilation_1, + dilation_2, + ]( + input: T.memref(S, S, S, S, S, T.f32()), + output: T.memref(S, S, S, S, S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1, kernel_size_2), + T.f32(), + ) + + linalg.pooling_ncdhw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1, stride_2], + dilations=[dilation_0, dilation_1, dilation_2], + ) + + kernel_sizes = [1, 2, 3] + strides = [4, 5, 6] + dilations = [7, 8, 9] + maxpool3d_k = maxpool3d[ + kernel_sizes[0], + kernel_sizes[1], + kernel_sizes[2], + strides[0], + strides[1], + strides[2], + dilations[0], + dilations[1], + dilations[2], + ].emit() + module = run_pipeline( + ctx.module, + Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), + ) + # CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1 * 7)> + # CHECK: #map1 = affine_map<(d0, d1) -> (d0 * 5 + d1 * 8)> + # CHECK: #map2 = affine_map<(d0, d1) -> (d0 * 6 + d1 * 9)> + # CHECK: module { + # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_4_int_5_int_6_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { + # CHECK: %c4 = arith.constant 4 : index + # CHECK: %c3 = arith.constant 3 : index + # CHECK: %c2 = arith.constant 2 : index + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c0 = arith.constant 0 : index + # CHECK: %dim = memref.dim %arg0, %c0 : memref + # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref + # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref + # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref + # CHECK: %dim_3 = memref.dim %arg1, %c4 : memref + # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2, %dim_3) step (%c1, %c1, %c1, %c1, %c1) { + # CHECK: scf.for %arg7 = %c0 to %c1 step %c1 { + # CHECK: scf.for %arg8 = %c0 to %c2 step %c1 { + # CHECK: scf.for %arg9 = %c0 to %c3 step %c1 { + # CHECK: %0 = affine.apply #map(%arg4, %arg7) + # CHECK: %1 = affine.apply #map1(%arg5, %arg8) + # CHECK: %2 = affine.apply #map2(%arg6, %arg9) + # CHECK: %3 = memref.load %arg0[%arg2, %arg3, %0, %1, %2] : memref + # CHECK: %4 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref + # CHECK: %5 = arith.maximumf %3, %4 : f32 + # CHECK: memref.store %5, %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref + # CHECK: } + # CHECK: } + # CHECK: } + # CHECK: scf.reduce + # CHECK: } + # CHECK: return + # CHECK: } + # CHECK: } + filecheck_with_comments(module) diff --git a/projects/eudsl-python-extras/tests/dialect/test_gpu.py b/projects/eudsl-python-extras/tests/dialect/test_gpu.py index 66d8c165..5cede87a 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_gpu.py +++ b/projects/eudsl-python-extras/tests/dialect/test_gpu.py @@ -1,12 +1,9 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -import platform import random import sys -import tempfile import time -from textwrap import dedent import mlir.extras.types as T import numpy as np @@ -30,7 +27,6 @@ thread_idx, block_dim, GPUModuleMeta, - func as gpu_func, set_container_module, launch, all_reduce_, @@ -119,7 +115,7 @@ def test_class(ctx: MLIRContext): M, N, K = 4 * scale, 16 * scale, 8 * scale class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): - @gpu_func(emit=True) + @gpu.func(emit=True) @canonicalize(using=scf.canonicalizer) def mat_product_kernel( A: T.memref(M, N, T.f32()), @@ -156,7 +152,7 @@ def test_class_call(ctx: MLIRContext): set_container_module(ctx.module) class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): - @gpu_func(emit=True, emit_grid=True) + @gpu.func(emit=True, emit_grid=True) @canonicalize(using=scf.canonicalizer) def mat_product_kernel( A: T.memref(M, N, T.f32()), @@ -214,7 +210,7 @@ def test_class_call_from_func(ctx: MLIRContext): set_container_module(ctx.module) class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): - @gpu_func(emit=True, emit_grid=True) + @gpu.func(emit=True, emit_grid=True) @canonicalize(using=scf.canonicalizer) def mat_product_kernel( A: T.memref(M, N, T.f32()), @@ -283,7 +279,7 @@ def test_async_object(ctx: MLIRContext): set_container_module(ctx.module) class MyClass1(metaclass=GPUModuleMeta, targets=["#nvvm.target"]): - @gpu_func(emit=True, emit_grid=True) + @gpu.func(emit=True, emit_grid=True) @canonicalize(using=scf.canonicalizer) def mat_product_kernel( A: T.memref(M, N, T.f32()), @@ -524,195 +520,12 @@ def kernel(bx, by, bz, tx, ty, tz, *grid_block_sizes): filecheck_with_comments(module) -@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") -def test_generics(ctx: MLIRContext): - set_container_module(ctx.module) - - # dodge <3.12 parser that doesn't support square brackets generics - exec( - dedent( - """\ - @gpu_func - def mat_product_kernel[ - M, K, N, dtype - ]( - A: "T.memref(M, K, dtype)", - B: "T.memref(K, N, dtype)", - C: "T.memref(M, N, dtype)", - ): - x = block_dim.x * block_idx.x + thread_idx.x - y = block_dim.y * block_idx.y + thread_idx.y - - one = arith.constant(1.0, type=dtype) - tmp = arith.constant(0, type=dtype) - for k, tmp, _ in scf.range_(K, iter_args=[tmp]): - tmp += A[x, k] * B[k, y] - tmp = scf.yield_(tmp) - C[x, y] = tmp + one - - globals()["mat_product_kernel"] = mat_product_kernel - """ - ) - ) - - @module("naive", ["#nvvm.target"]) - def _(): - mat_product_kernel[32, 32, 32, T.f32()].emit() # noqa: F821 - - correct = dedent( - """\ - module attributes {gpu.container_module} { - gpu.module @naive [#nvvm.target] { - gpu.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) kernel { - %block_dim_x = gpu.block_dim x - %block_id_x = gpu.block_id x - %0 = arith.muli %block_dim_x, %block_id_x : index - %thread_id_x = gpu.thread_id x - %1 = arith.addi %0, %thread_id_x : index - %block_dim_y = gpu.block_dim y - %block_id_y = gpu.block_id y - %2 = arith.muli %block_dim_y, %block_id_y : index - %thread_id_y = gpu.thread_id y - %3 = arith.addi %2, %thread_id_y : index - %cst = arith.constant 1.000000e+00 : f32 - %cst_0 = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c32 = arith.constant 32 : index - %c1 = arith.constant 1 : index - %4 = scf.for %arg3 = %c0 to %c32 step %c1 iter_args(%arg4 = %cst_0) -> (f32) { - %6 = memref.load %arg0[%1, %arg3] : memref<32x32xf32> - %7 = memref.load %arg1[%arg3, %3] : memref<32x32xf32> - %8 = arith.mulf %6, %7 : f32 - %9 = arith.addf %arg4, %8 : f32 - scf.yield %9 : f32 - } - %5 = arith.addf %4, %cst : f32 - memref.store %5, %arg2[%1, %3] : memref<32x32xf32> - gpu.return - } - } - } - """ - ) - print(ctx.module) - - filecheck(correct, ctx.module) - - -@pytest.mark.skipif(sys.version_info < (3, 12), reason="requires python3.12 or higher") -def test_generic_type_var_closure_patching(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - exec( - dedent( - """\ - from mlir.extras.ast.py_type import PyTypeVarObject - - def fun2[foo, bar, A: foo + bar](): - print(A.__bound__) - - - A_type_param = fun2.__type_params__[2] - - - a = PyTypeVarObject.from_object(A_type_param) - a_something = a.bound.contents.into_object() - a_something.__closure__[0].cell_contents = 5 - a_something.__closure__[1].cell_contents = 7 - - fun2() - """ - ) - ) - - -@pytest.mark.skipif( - sys.version_info < (3, 13) or platform.system() == "Windows", - reason="requires python3.13 or higher (and windows can't find the source file)", -) -def test_generic_type_var_closure_patching_dependent_generics(ctx: MLIRContext): - # dodge <3.12 parser that doesn't support square brackets generics - # but also need a real file here because rewriter needs source... - src = dedent( - """\ - from mlir.extras.dialects import arith, gpu, scf - from mlir.extras.ast.canonicalize import canonicalize - import mlir.extras.types as T - - @gpu.func - def test_plain[ - M, - K, - N, - dtype, - A_t = T.memref(M, K, dtype), - B_t = T.memref(K, N, dtype), - C_t = T.memref(M, N, dtype), - ](A: A_t, B: B_t, C: C_t): - one = arith.constant(1.0, type=dtype) - - @gpu.func - @canonicalize(using=(arith.canonicalizer, scf.canonicalizer)) - def test_2_with_rewrite[ - M, - K, - N, - dtype, - A_t = T.memref(M, K, dtype), - B_t = T.memref(K, N, dtype), - C_t = T.memref(M, N, dtype), - ](A: A_t, B: B_t, C: C_t): - one = arith.constant(1.0, type=dtype) - - globals()["test_plain"] = test_plain - globals()["test_2_with_rewrite"] = test_2_with_rewrite - """ - ) - - with tempfile.NamedTemporaryFile(mode="w") as tmp: - tmp.write(src) - tmp.flush() - code = compile(src, tmp.name, "exec") - exec(code, globals(), locals()) - - @module("mod1", ["#nvvm.target"]) - def _(): - test_plain[1, 2, 3, T.f32()].emit() # noqa: F821 - test_2_with_rewrite[1, 2, 3, T.f32()].emit() # noqa: F821 - - @module("mod2", ["#nvvm.target"]) - def _(): - test_plain[4, 5, 6, T.f16()].emit() # noqa: F821 - test_2_with_rewrite[4, 5, 6, T.f16()].emit() # noqa: F821 - - # CHECK: gpu.module @mod1 [#nvvm.target] { - # CHECK: gpu.func @"test_plain_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) kernel { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: gpu.return - # CHECK: } - # CHECK: gpu.func @"test_2_with_rewrite_int_1_int_2_int_3_type_f32_MemRefType_memref<1x2xf32>_MemRefType_memref<2x3xf32>_MemRefType_memref<1x3xf32>"(%arg0: memref<1x2xf32>, %arg1: memref<2x3xf32>, %arg2: memref<1x3xf32>) kernel { - # CHECK: %cst = arith.constant 1.000000e+00 : f32 - # CHECK: gpu.return - # CHECK: } - # CHECK: } - # CHECK: gpu.module @mod2 [#nvvm.target] { - # CHECK: gpu.func @"test_plain_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) kernel { - # CHECK: %cst = arith.constant 1.000000e+00 : f16 - # CHECK: gpu.return - # CHECK: } - # CHECK: gpu.func @"test_2_with_rewrite_int_4_int_5_int_6_type_f16_MemRefType_memref<4x5xf16>_MemRefType_memref<5x6xf16>_MemRefType_memref<4x6xf16>"(%arg0: memref<4x5xf16>, %arg1: memref<5x6xf16>, %arg2: memref<4x6xf16>) kernel { - # CHECK: %cst = arith.constant 1.000000e+00 : f16 - # CHECK: gpu.return - # CHECK: } - # CHECK: } - filecheck_with_comments(ctx.module) - - def test_amdgpu(ctx: MLIRContext): set_container_module(ctx.module) M, K, N, dtype = 32, 32, 32, T.f32() - @gpu_func + @gpu.func def mat_product_kernel( A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) ): @@ -824,7 +637,7 @@ def test_amdgpu_square(ctx: MLIRContext): scale = 1024 M, K, N, dtype = scale, scale, scale, T.f32() - @gpu_func + @gpu.func def mat_product_kernel( A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype) ): @@ -938,7 +751,7 @@ def test_amdgpu_vector(ctx: MLIRContext): tz_a, tz_b, tz_c = [2, 2, 2] v2f32 = T.vector(2, T.f32()) - @gpu_func + @gpu.func def smol_matmul( A: T.memref(M, K, T.f32()), B: T.memref(K, N, T.f32()), @@ -1063,13 +876,13 @@ def test_amdgpu_bank_conflicts(ctx: MLIRContext): M = 128 - @gpu_func + @gpu.func def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())): for i in range(M): a = A[i, thread_idx.x] B[i, thread_idx.x] = a * a - @gpu_func + @gpu.func def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())): for i in range(M): a = A[i, thread_idx.x] @@ -1180,7 +993,7 @@ def test_amdgpu_vector_wmma(ctx: MLIRContext): M, K, N = v_len, v_len, v_len v16f16 = T.vector(v_len, T.f16()) - @gpu_func + @gpu.func @canonicalize(using=scf.canonicalizer) def smol_matmul( a: T.memref(M, K, T.f16()), diff --git a/projects/eudsl-python-extras/tests/dialect/test_linalg.py b/projects/eudsl-python-extras/tests/dialect/test_linalg.py index 9b20aa5b..8836668b 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_linalg.py +++ b/projects/eudsl-python-extras/tests/dialect/test_linalg.py @@ -6,9 +6,6 @@ import pytest from mlir.extras.dialects import linalg, memref, tensor -from mlir.ir import ShapedType - -from mlir.extras.dialects.func import func # noinspection PyUnresolvedReferences from mlir.extras.testing import ( @@ -17,7 +14,6 @@ filecheck_with_comments, mlir_ctx as ctx, ) -from mlir.extras.runtime.passes import Pipeline, run_pipeline # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx") @@ -48,211 +44,3 @@ def test_np_constructor(ctx: MLIRContext): # CHECK: %[[VAL_11:.*]] = linalg.fill ins(%[[VAL_10]] : i32) outs(%[[VAL_5]] : tensor<10x10xi32>) -> tensor<10x10xi32> filecheck_with_comments(ctx.module) - - -def test_pooling_nchw_max(ctx: MLIRContext): - S = ShapedType.get_dynamic_size() - - @func - def maxpool2d[ - kernel_size_0, kernel_size_1, stride_0, stride_1, dilation_0, dilation_1 - ]( - input: T.memref(S, S, S, S, T.f32()), - output: T.memref(S, S, S, S, T.f32()), - ): - kernel_shape_surrogate = memref.alloca( - (kernel_size_0, kernel_size_1), - T.f32(), - ) - - linalg.pooling_nchw_max( - input, - kernel_shape_surrogate, - output, - strides=[stride_0, stride_1], - dilations=[dilation_0, dilation_1], - ) - - kernel_sizes = [2, 3] - strides = [4, 5] - dilations = [6, 7] - maxpool2d_k = maxpool2d[ - kernel_sizes[0], - kernel_sizes[1], - strides[0], - strides[1], - dilations[0], - dilations[1], - ].emit() - module = run_pipeline( - ctx.module, - Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), - ) - # CHECK: func.func @maxpool2d_int_2_int_3_int_4_int_5_int_6_int_7(%arg0: memref, %arg1: memref) { - # CHECK: %c3 = arith.constant 3 : index - # CHECK: %c2 = arith.constant 2 : index - # CHECK: %c1 = arith.constant 1 : index - # CHECK: %c0 = arith.constant 0 : index - # CHECK: %dim = memref.dim %arg0, %c0 : memref - # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref - # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref - # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref - # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5) = (%c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2) step (%c1, %c1, %c1, %c1) { - # CHECK: scf.for %arg6 = %c0 to %c2 step %c1 { - # CHECK: scf.for %arg7 = %c0 to %c3 step %c1 { - # CHECK: %0 = affine.apply #map(%arg4, %arg6) - # CHECK: %1 = affine.apply #map1(%arg5, %arg7) - # CHECK: %2 = memref.load %arg0[%arg2, %arg3, %0, %1] : memref - # CHECK: %3 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5] : memref - # CHECK: %4 = arith.maximumf %3, %2 : f32 - # CHECK: memref.store %4, %arg1[%arg2, %arg3, %arg4, %arg5] : memref - # CHECK: } - # CHECK: } - # CHECK: scf.reduce - # CHECK: } - # CHECK: return - # CHECK: } - filecheck_with_comments(module) - - -def test_pooling_ncdhw_max(ctx: MLIRContext): - S = ShapedType.get_dynamic_size() - - @func - def maxpool3d[ - kernel_size_0, - kernel_size_1, - kernel_size_2, - stride_0, - stride_1, - stride_2, - dilation_0, - dilation_1, - dilation_2, - ]( - input: T.memref(S, S, S, S, S, T.f32()), - output: T.memref(S, S, S, S, S, T.f32()), - ): - kernel_shape_surrogate = memref.alloca( - (kernel_size_0, kernel_size_1, kernel_size_2), - T.f32(), - ) - - linalg.pooling_ncdhw_max( - input, - kernel_shape_surrogate, - output, - strides=[stride_0, stride_1, stride_2], - dilations=[dilation_0, dilation_1, dilation_2], - ) - - kernel_sizes = [1, 2, 3] - strides = [5, 6, 7] - dilations = [7, 8, 9] - maxpool3d_k = maxpool3d[ - kernel_sizes[0], - kernel_sizes[1], - kernel_sizes[2], - strides[0], - strides[1], - strides[2], - dilations[0], - dilations[1], - dilations[2], - ].emit() - # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_5_int_6_int_7_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { - # CHECK: %alloca = memref.alloca() : memref<1x2x3xf32> - # CHECK: linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 5 + d5 * 7, d3 * 6 + d6 * 8, d4 * 7 + d7 * 9)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>, affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%arg0, %alloca : memref, memref<1x2x3xf32>) outs(%arg1 : memref) { - # CHECK: ^bb0(%in: f32, %in_0: f32, %out: f32): - # CHECK: %0 = arith.maximumf %in, %out : f32 - # CHECK: linalg.yield %0 : f32 - # CHECK: } - # CHECK: return - # CHECK: } - filecheck_with_comments(maxpool3d_k) - - -def test_pooling_ncdhw_max_parallel(ctx: MLIRContext): - S = ShapedType.get_dynamic_size() - - @func - def maxpool3d[ - kernel_size_0, - kernel_size_1, - kernel_size_2, - stride_0, - stride_1, - stride_2, - dilation_0, - dilation_1, - dilation_2, - ]( - input: T.memref(S, S, S, S, S, T.f32()), - output: T.memref(S, S, S, S, S, T.f32()), - ): - kernel_shape_surrogate = memref.alloca( - (kernel_size_0, kernel_size_1, kernel_size_2), - T.f32(), - ) - - linalg.pooling_ncdhw_max( - input, - kernel_shape_surrogate, - output, - strides=[stride_0, stride_1, stride_2], - dilations=[dilation_0, dilation_1, dilation_2], - ) - - kernel_sizes = [1, 2, 3] - strides = [4, 5, 6] - dilations = [7, 8, 9] - maxpool3d_k = maxpool3d[ - kernel_sizes[0], - kernel_sizes[1], - kernel_sizes[2], - strides[0], - strides[1], - strides[2], - dilations[0], - dilations[1], - dilations[2], - ].emit() - module = run_pipeline( - ctx.module, - Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), - ) - # CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1 * 7)> - # CHECK: #map1 = affine_map<(d0, d1) -> (d0 * 5 + d1 * 8)> - # CHECK: #map2 = affine_map<(d0, d1) -> (d0 * 6 + d1 * 9)> - # CHECK: module { - # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_4_int_5_int_6_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { - # CHECK: %c4 = arith.constant 4 : index - # CHECK: %c3 = arith.constant 3 : index - # CHECK: %c2 = arith.constant 2 : index - # CHECK: %c1 = arith.constant 1 : index - # CHECK: %c0 = arith.constant 0 : index - # CHECK: %dim = memref.dim %arg0, %c0 : memref - # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref - # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref - # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref - # CHECK: %dim_3 = memref.dim %arg1, %c4 : memref - # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2, %dim_3) step (%c1, %c1, %c1, %c1, %c1) { - # CHECK: scf.for %arg7 = %c0 to %c1 step %c1 { - # CHECK: scf.for %arg8 = %c0 to %c2 step %c1 { - # CHECK: scf.for %arg9 = %c0 to %c3 step %c1 { - # CHECK: %0 = affine.apply #map(%arg4, %arg7) - # CHECK: %1 = affine.apply #map1(%arg5, %arg8) - # CHECK: %2 = affine.apply #map2(%arg6, %arg9) - # CHECK: %3 = memref.load %arg0[%arg2, %arg3, %0, %1, %2] : memref - # CHECK: %4 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref - # CHECK: %5 = arith.maximumf %3, %4 : f32 - # CHECK: memref.store %5, %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref - # CHECK: } - # CHECK: } - # CHECK: } - # CHECK: scf.reduce - # CHECK: } - # CHECK: return - # CHECK: } - # CHECK: } - filecheck_with_comments(module)