From 6ce4cf3276428d5d88e8c5dd71076cd21afba18b Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 19 Nov 2025 16:08:47 -0800 Subject: [PATCH 1/5] [eudsl-python-extras] better handle callable generics --- ...build_test_release_eudsl_python_extras.yml | 17 ++- .../mlir/extras/ast/util.py | 29 +++-- .../mlir/extras/dialects/func.py | 37 +++--- .../tests/dialect/test_func.py | 41 +++++++ .../tests/dialect/test_linalg.py | 116 ++++++++++++++++++ 5 files changed, 208 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build_test_release_eudsl_python_extras.yml b/.github/workflows/build_test_release_eudsl_python_extras.yml index facaddb2..b989b8d3 100644 --- a/.github/workflows/build_test_release_eudsl_python_extras.yml +++ b/.github/workflows/build_test_release_eudsl_python_extras.yml @@ -94,7 +94,7 @@ jobs: "windows-2022" ] python-version: [ - # "3.10", "3.11", "3.12", + "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t" ] include: [ @@ -118,6 +118,21 @@ jobs: - runs-on: macos-13 python-version: "3.14t" + - runs-on: macos-14 + python-version: "3.10" + + - runs-on: macos-14 + python-version: "3.11" + + - runs-on: macos-14 + python-version: "3.12" + + - runs-on: macos-14 + python-version: "3.13" + + - runs-on: macos-14 + python-version: "3.14" + runs-on: ${{ matrix.runs-on }} name: "Test eudsl-python-extras ${{ matrix.name }} ${{ matrix.python-version }}" diff --git a/projects/eudsl-python-extras/mlir/extras/ast/util.py b/projects/eudsl-python-extras/mlir/extras/ast/util.py index b176eecf..58a1e089 100644 --- a/projects/eudsl-python-extras/mlir/extras/ast/util.py +++ b/projects/eudsl-python-extras/mlir/extras/ast/util.py @@ -38,9 +38,9 @@ def ast_call(name, args=None, keywords=None): def get_module_cst(f): f_src = dedent(inspect.getsource(f)) tree = ast.parse(f_src) - assert isinstance(tree.body[0], ast.FunctionDef), ( - f"unexpected ast node {tree.body[0]}" - ) + assert isinstance( + tree.body[0], ast.FunctionDef + ), f"unexpected ast node {tree.body[0]}" return tree @@ -130,18 +130,23 @@ def reducer_override(self, obj): return super().reducer_override(obj) +def copy_object(obj): + # see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813 + # for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars) + with io.BytesIO() as file: + cp = MLIRTypePickler(file) + cp.dump(obj) + obj = cloudpickle.loads(file.getvalue()) + return obj + + # Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard); # potentially more complete approach https://stackoverflow.com/a/56901529/9045206 def copy_func(f, new_closure: Dict = None): if new_closure is not None: code, closure = replace_closure(f.__code__, new_closure) else: - # see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813 - # for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars) - with io.BytesIO() as file: - cp = MLIRTypePickler(file) - cp.dump(f.__closure__) - closure = cloudpickle.loads(file.getvalue()) + closure = copy_object(f.__closure__) code = f.__code__ g = types.FunctionType( @@ -162,9 +167,9 @@ def copy_func(f, new_closure: Dict = None): def append_hidden_node(node_body, new_node): last_statement = node_body[-1] - assert last_statement.end_lineno is not None, ( - f"last_statement {ast.unparse(last_statement)} must have end_lineno" - ) + assert ( + last_statement.end_lineno is not None + ), f"last_statement {ast.unparse(last_statement)} must have end_lineno" new_node = ast.fix_missing_locations( set_lineno(new_node, last_statement.end_lineno) ) diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index 02fbabd7..e9658ec0 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -8,7 +8,7 @@ from functools import update_wrapper from typing import Optional, List, Union, TypeVar -from ..ast.util import copy_func +from ..ast.util import copy_func, copy_object from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject from ..meta import op_region_builder from .. import types as T @@ -114,9 +114,7 @@ def prep_func_types(sig, return_types): return_types = list(return_types) assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types - ), ( - f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" - ) + ), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" input_types = [ p.annotation @@ -125,9 +123,7 @@ def prep_func_types(sig, return_types): ] assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types - ), ( - f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" - ) + ), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" user_loc = get_user_code_loc() # If ir.Context is none (like for deferred func emit) if user_loc is None: @@ -179,7 +175,7 @@ def __init__( self.call_op_ctor = call_op_ctor self.arg_attrs = arg_attrs self.res_attrs = res_attrs - self.generics = generics + self.generics = copy_object(generics) self.loc = loc self.ip = ip self._func_op = None @@ -200,9 +196,9 @@ def __init__( ) if self._is_decl(): - assert len(self.input_types) == len(sig.parameters), ( - f"func decl needs all input types annotated" - ) + assert len(self.input_types) == len( + sig.parameters + ), f"func decl needs all input types annotated" self.sym_visibility = "private" self.emit() @@ -374,11 +370,18 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): body_builder.__globals__[t.__name__] = r.val if r.name in body_builder.__code__.co_freevars: free_i = body_builder.__code__.co_freevars.index(r.name) - assert body_builder.__closure__[free_i].cell_contents == t, ( - "typevars don't match" - ) + assert ( + body_builder.__closure__[free_i].cell_contents == t + ), "typevars don't match" body_builder.__closure__[free_i].cell_contents = r.val + name_mangled_generics = [] + for r in reified_type_params: + t, v = r.type, r.val + if callable(v): + v = v.__name__ + name_mangled_generics.append(f"{t}_{v}") + return FuncBase( body_builder, self.func_op_ctor, @@ -386,11 +389,7 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): self.call_op_ctor, return_types=self.return_types, sym_visibility=self.sym_visibility, - sym_name=( - self.func_name - + "_" - + "_".join([f"{r.type}_{r.val}" for r in reified_type_params]) - ), + sym_name=(self.func_name + "_" + "_".join(name_mangled_generics)), arg_attrs=self.arg_attrs, res_attrs=self.res_attrs, func_attrs=self.func_attrs, diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index e725db50..931f2f01 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -205,11 +205,52 @@ def mat_product_kernel( one = arith.constant(1, dtype) mat_product_kernel[32, 32, 32, T.i32()].emit() + mat_product_kernel[32, 32, 32, T.f32()].emit() # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_i32(%[[VAL_0:.*]]: memref<32x32xi32>, %[[VAL_1:.*]]: memref<32x32xi32>, %[[VAL_2:.*]]: memref<32x32xi32>) { # CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 # CHECK: return # CHECK: } + # CHECK: func.func @mat_product_kernel_int_32_int_32_int_32_type_f32(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: return + # CHECK: } + + filecheck_with_comments(ctx.module) + + +def test_generics_callable(ctx: MLIRContext): + _op = TypeVar("_op") + + @func(generics=[_op]) + def mat_product_kernel1(): + one = arith.constant(1, T.f32()) + two = _op(one, one) + + @func(generics=[_op]) + def mat_product_kernel2(): + one = arith.constant(1, T.f32()) + two = _op(one, one) + + mat_product_kernel1[arith.maximumf,].emit() + mat_product_kernel2[arith.minimumf,].emit() + mat_product_kernel2[arith.maximumf,].emit() + + # CHECK: func.func @mat_product_kernel1_function_maximumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.maximumf %cst, %cst : f32 + # CHECK: return + # CHECK: } + # CHECK: func.func @mat_product_kernel2_function_minimumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.minimumf %cst, %cst : f32 + # CHECK: return + # CHECK: } + # CHECK: func.func @mat_product_kernel2_function_maximumf() { + # CHECK: %cst = arith.constant 1.000000e+00 : f32 + # CHECK: %0 = arith.maximumf %cst, %cst : f32 + # CHECK: return + # CHECK: } filecheck_with_comments(ctx.module) diff --git a/projects/eudsl-python-extras/tests/dialect/test_linalg.py b/projects/eudsl-python-extras/tests/dialect/test_linalg.py index 0fd5c32c..d76fa93e 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_linalg.py +++ b/projects/eudsl-python-extras/tests/dialect/test_linalg.py @@ -18,6 +18,7 @@ filecheck_with_comments, mlir_ctx as ctx, ) +from mlir.extras.runtime.passes import Pipeline, run_pipeline # needed since the fix isn't defined here nor conftest.py pytest.mark.usefixtures("ctx") @@ -134,3 +135,118 @@ def maxpool3d( # CHECK: return # CHECK: } filecheck_with_comments(maxpool3d_k) + + +def test_pooling_ncdhw_max_parallel(ctx: MLIRContext): + S = ShapedType.get_dynamic_size() + + generics = ( + kernel_size_0, + kernel_size_1, + kernel_size_2, + stride_0, + stride_1, + stride_2, + dilation_0, + dilation_1, + dilation_2, + ) = list( + map( + TypeVar, + [ + "kernel_size_0", + "kernel_size_1", + "kernel_size_2", + "stride_0", + "stride_1", + "stride_2", + "dilation_0", + "dilation_1", + "dilation_2", + ], + ) + ) + + @func( + generics=( + kernel_size_0, + kernel_size_1, + kernel_size_2, + stride_0, + stride_1, + stride_2, + dilation_0, + dilation_1, + dilation_2, + ) + ) + def maxpool3d( + input: T.memref(S, S, S, S, S, T.f32()), + output: T.memref(S, S, S, S, S, T.f32()), + ): + kernel_shape_surrogate = memref.alloca( + (kernel_size_0, kernel_size_1, kernel_size_2), + T.f32(), + ) + + linalg.pooling_ncdhw_max( + input, + kernel_shape_surrogate, + output, + strides=[stride_0, stride_1, stride_2], + dilations=[dilation_0, dilation_1, dilation_2], + ) + + kernel_sizes = [1, 2, 3] + strides = [4, 5, 6] + dilations = [7, 8, 9] + maxpool3d_k = maxpool3d[ + kernel_sizes[0], + kernel_sizes[1], + kernel_sizes[2], + strides[0], + strides[1], + strides[2], + dilations[0], + dilations[1], + dilations[2], + ].emit() + module = run_pipeline( + ctx.module, + Pipeline().bufferize().Func(Pipeline().convert_linalg_to_parallel_loops()), + ) + # CHECK: #map = affine_map<(d0, d1) -> (d0 * 4 + d1 * 7)> + # CHECK: #map1 = affine_map<(d0, d1) -> (d0 * 5 + d1 * 8)> + # CHECK: #map2 = affine_map<(d0, d1) -> (d0 * 6 + d1 * 9)> + # CHECK: module { + # CHECK: func.func @maxpool3d_int_1_int_2_int_3_int_4_int_5_int_6_int_7_int_8_int_9(%arg0: memref, %arg1: memref) { + # CHECK: %c4 = arith.constant 4 : index + # CHECK: %c3 = arith.constant 3 : index + # CHECK: %c2 = arith.constant 2 : index + # CHECK: %c1 = arith.constant 1 : index + # CHECK: %c0 = arith.constant 0 : index + # CHECK: %dim = memref.dim %arg0, %c0 : memref + # CHECK: %dim_0 = memref.dim %arg0, %c1 : memref + # CHECK: %dim_1 = memref.dim %arg1, %c2 : memref + # CHECK: %dim_2 = memref.dim %arg1, %c3 : memref + # CHECK: %dim_3 = memref.dim %arg1, %c4 : memref + # CHECK: scf.parallel (%arg2, %arg3, %arg4, %arg5, %arg6) = (%c0, %c0, %c0, %c0, %c0) to (%dim, %dim_0, %dim_1, %dim_2, %dim_3) step (%c1, %c1, %c1, %c1, %c1) { + # CHECK: scf.for %arg7 = %c0 to %c1 step %c1 { + # CHECK: scf.for %arg8 = %c0 to %c2 step %c1 { + # CHECK: scf.for %arg9 = %c0 to %c3 step %c1 { + # CHECK: %0 = affine.apply #map(%arg4, %arg7) + # CHECK: %1 = affine.apply #map1(%arg5, %arg8) + # CHECK: %2 = affine.apply #map2(%arg6, %arg9) + # CHECK: %3 = memref.load %arg0[%arg2, %arg3, %0, %1, %2] : memref + # CHECK: %4 = memref.load %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref + # CHECK: %5 = arith.maximumf %3, %4 : f32 + # CHECK: memref.store %5, %arg1[%arg2, %arg3, %arg4, %arg5, %arg6] : memref + # CHECK: } + # CHECK: } + # CHECK: } + # CHECK: scf.reduce + # CHECK: } + # CHECK: return + # CHECK: } + # CHECK: } + filecheck_with_comments(module) From 86b71826409ea489d64d79d8b975ee07d7411a0d Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 19 Nov 2025 16:22:35 -0800 Subject: [PATCH 2/5] fix strenum --- .../mlir/extras/runtime/_passes_base.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py b/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py index 7c576ef4..58bec28f 100644 --- a/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py +++ b/projects/eudsl-python-extras/mlir/extras/runtime/_passes_base.py @@ -6,7 +6,6 @@ import sys import tempfile from contextlib import ExitStack -from enum import StrEnum from io import StringIO from typing import List, Optional, Union @@ -14,6 +13,16 @@ from ...ir import Module, StringAttr from ...passmanager import PassManager + +try: + from enum import StrEnum +except ImportError: + from enum import Enum + + class StrEnum(str, Enum): + pass + + logger = logging.getLogger(__name__) From db68c5965bb2d00655ecad83cde68a45aaf0f2c8 Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 19 Nov 2025 16:35:43 -0800 Subject: [PATCH 3/5] failing case --- .../tests/dialect/test_func.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index 931f2f01..5ceec64b 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -255,6 +255,30 @@ def mat_product_kernel2(): filecheck_with_comments(ctx.module) +_op = TypeVar("_op") + + +def test_global_closures(ctx: MLIRContext): + globals()["_op"] = TypeVar("_op") + + @func(generics=[_op]) + def _generic_pool2d_scf(a: T.f32(), b: T.f32()): + _op(a, b) + + _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf,] + + # _op = TypeVar("_op") + + @func(generics=[_op]) + def _generic_pool3d_scf( + a: T.f32(), + b: T.f32(), + ): + _op(a, b) + + _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf,] + + def test_generics_with_canonicalizations(ctx: MLIRContext): generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"])) From ce4c3e313f305b8698d599aa24a38d3720e7080c Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 19 Nov 2025 19:38:36 -0800 Subject: [PATCH 4/5] check for failing case --- .../mlir/extras/ast/util.py | 14 +++--- .../mlir/extras/dialects/func.py | 50 ++++++++++++------- projects/eudsl-python-extras/requirements.txt | 2 +- .../tests/dialect/test_func.py | 14 ++++-- 4 files changed, 48 insertions(+), 32 deletions(-) diff --git a/projects/eudsl-python-extras/mlir/extras/ast/util.py b/projects/eudsl-python-extras/mlir/extras/ast/util.py index 58a1e089..32f5989c 100644 --- a/projects/eudsl-python-extras/mlir/extras/ast/util.py +++ b/projects/eudsl-python-extras/mlir/extras/ast/util.py @@ -38,9 +38,9 @@ def ast_call(name, args=None, keywords=None): def get_module_cst(f): f_src = dedent(inspect.getsource(f)) tree = ast.parse(f_src) - assert isinstance( - tree.body[0], ast.FunctionDef - ), f"unexpected ast node {tree.body[0]}" + assert isinstance(tree.body[0], ast.FunctionDef), ( + f"unexpected ast node {tree.body[0]}" + ) return tree @@ -92,7 +92,7 @@ def replace_closure(code, new_closure: Dict): LOAD_DEREF = opmap["LOAD_DEREF"] # get the orig localplus that will be loaded from by the orig bytecode LOAD_DEREF arg_i - localsplus, localsplus_name_to_idx = get_localsplus_name_to_idx(code) + localsplus, _localsplus_name_to_idx = get_localsplus_name_to_idx(code) # closure vars go into co_freevars new_code = code.replace(co_freevars=tuple(new_closure.keys())) @@ -167,9 +167,9 @@ def copy_func(f, new_closure: Dict = None): def append_hidden_node(node_body, new_node): last_statement = node_body[-1] - assert ( - last_statement.end_lineno is not None - ), f"last_statement {ast.unparse(last_statement)} must have end_lineno" + assert last_statement.end_lineno is not None, ( + f"last_statement {ast.unparse(last_statement)} must have end_lineno" + ) new_node = ast.fix_missing_locations( set_lineno(new_node, last_statement.end_lineno) ) diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index e9658ec0..6bcbd72b 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -114,7 +114,9 @@ def prep_func_types(sig, return_types): return_types = list(return_types) assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types - ), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" + ), ( + f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" + ) input_types = [ p.annotation @@ -123,7 +125,9 @@ def prep_func_types(sig, return_types): ] assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types - ), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" + ), ( + f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" + ) user_loc = get_user_code_loc() # If ir.Context is none (like for deferred func emit) if user_loc is None: @@ -196,9 +200,9 @@ def __init__( ) if self._is_decl(): - assert len(self.input_types) == len( - sig.parameters - ), f"func decl needs all input types annotated" + assert len(self.input_types) == len(sig.parameters), ( + f"func decl needs all input types annotated" + ) self.sym_visibility = "private" self.emit() @@ -298,6 +302,8 @@ def __call__(self, *call_args): return call(self.emit(*call_args), call_args) def __getitem__(self, item): + if not isinstance(item, tuple): + item = (item,) if self.generics is None: raise RuntimeError( "using a generic call requires the func be generic (i.e., have type_params)" @@ -322,22 +328,26 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): continue if k not in already_reified_type_params: raise RuntimeError( - f"typevar {k} not reified prior to evaluating dependent typevar {t}" + f"typevar {k} not reified prior to evaluating dependent typevar {tvar}" ) cvrs[k] = already_reified_type_params[k] unevaled_type_data = copy_func(unevaled_type_data, cvrs) return unevaled_type_data() - generics = copy.deepcopy(self.generics) - for i, t in enumerate(generics): + generics = copy_object(self.generics) + for i, tvar in enumerate(generics): + if not isinstance(tvar, TypeVar): + raise RuntimeError( + f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, you should give it a unique name." + ) type_var_default = None if sys.version_info >= (3, 12): - type_var = PyTypeVarObject.from_object(t) + type_var = PyTypeVarObject.from_object(tvar) type_var_bound = type_var.bound - if sys.version_info >= (3, 13) and t.has_default(): + if sys.version_info >= (3, 13) and tvar.has_default(): type_var_default = type_var.default_value else: - type_var_bound = t.__bound__ + type_var_bound = tvar.__bound__ if bool(type_var_bound): # before 3.12 typevar was just a python class @@ -346,7 +356,7 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound) elif not bool(type_var_default): if i >= len(item): - raise RuntimeError(f"generic {t} must have concrete val") + raise RuntimeError(f"generic {tvar=} must have concrete val") if isinstance(item[i], Type): type_var_bound = "type" else: @@ -358,29 +368,31 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): val = type_var_default else: if i >= len(item): - raise RuntimeError(f"generic {t} must have concrete val") + raise RuntimeError(f"generic {tvar=} must have concrete val") val = item[i] - r = ReifiedTypeParams(t.__name__, val, type_var_bound) + r = ReifiedTypeParams(tvar.__name__, val, type_var_bound) reified_type_params.append(r) already_reified_type_params[r.name] = r.val - if t.__name__ in body_builder.__globals__: - body_builder.__globals__[t.__name__] = r.val + # replace the tvar in body_builder's global context with the reified val + if tvar.__name__ in body_builder.__globals__: + body_builder.__globals__[tvar.__name__] = r.val + # replace the tvar in body_builder's closure with the reified val if r.name in body_builder.__code__.co_freevars: free_i = body_builder.__code__.co_freevars.index(r.name) assert ( - body_builder.__closure__[free_i].cell_contents == t + body_builder.__closure__[free_i].cell_contents == tvar ), "typevars don't match" body_builder.__closure__[free_i].cell_contents = r.val name_mangled_generics = [] for r in reified_type_params: - t, v = r.type, r.val + tvar, v = r.type, r.val if callable(v): v = v.__name__ - name_mangled_generics.append(f"{t}_{v}") + name_mangled_generics.append(f"{tvar}_{v}") return FuncBase( body_builder, diff --git a/projects/eudsl-python-extras/requirements.txt b/projects/eudsl-python-extras/requirements.txt index a1bbc682..6e31ad2c 100644 --- a/projects/eudsl-python-extras/requirements.txt +++ b/projects/eudsl-python-extras/requirements.txt @@ -1,4 +1,4 @@ PyYAML>=5.4.0 -bytecode @ git+https://github.com/MatthieuDartiailh/bytecode +bytecode>=0.17.0 cloudpickle>=3.0.0 numpy>=1.19.5, <=2.1.2 diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index 5ceec64b..b3b4b767 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -232,9 +232,9 @@ def mat_product_kernel2(): one = arith.constant(1, T.f32()) two = _op(one, one) - mat_product_kernel1[arith.maximumf,].emit() - mat_product_kernel2[arith.minimumf,].emit() - mat_product_kernel2[arith.maximumf,].emit() + mat_product_kernel1[arith.maximumf].emit() + mat_product_kernel2[arith.minimumf].emit() + mat_product_kernel2[arith.maximumf].emit() # CHECK: func.func @mat_product_kernel1_function_maximumf() { # CHECK: %cst = arith.constant 1.000000e+00 : f32 @@ -265,7 +265,7 @@ def test_global_closures(ctx: MLIRContext): def _generic_pool2d_scf(a: T.f32(), b: T.f32()): _op(a, b) - _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf,] + _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf] # _op = TypeVar("_op") @@ -276,7 +276,11 @@ def _generic_pool3d_scf( ): _op(a, b) - _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf,] + with pytest.raises( + RuntimeError, + match="0th generic has probably already been reified as ; if you're using a global tvar for the generic, you should give it a unique name.", + ): + _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf] def test_generics_with_canonicalizations(ctx: MLIRContext): From 232d19a6870ce82c84e849eb7ac08e22d16e901d Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 19 Nov 2025 19:49:58 -0800 Subject: [PATCH 5/5] fix GHA too --- .../build_mlir_python_bindings_wheel.yml | 17 +++++++- .../workflows/build_test_release_eudsl.yml | 42 ++++++++++++++++--- .github/workflows/clean_releases.yml | 2 +- projects/eudsl-llvmpy/src/llvm/function.py | 11 ++--- .../mlir/extras/dialects/func.py | 28 ++++++------- .../tests/dialect/test_func.py | 16 +++++-- 6 files changed, 84 insertions(+), 32 deletions(-) diff --git a/.github/workflows/build_mlir_python_bindings_wheel.yml b/.github/workflows/build_mlir_python_bindings_wheel.yml index 16fe30a6..0e42ed48 100644 --- a/.github/workflows/build_mlir_python_bindings_wheel.yml +++ b/.github/workflows/build_mlir_python_bindings_wheel.yml @@ -206,7 +206,7 @@ jobs: "windows-2022" ] python-version: [ - # "3.10", "3.11", "3.12", + "3.10", "3.11", "3.12", "3.13", "3.14", "3.14t" ] include: [ @@ -230,6 +230,21 @@ jobs: - runs-on: macos-13 python-version: "3.14t" + - runs-on: macos-14 + python-version: "3.10" + + - runs-on: macos-14 + python-version: "3.11" + + - runs-on: macos-14 + python-version: "3.12" + + - runs-on: macos-14 + python-version: "3.13" + + - runs-on: macos-14 + python-version: "3.14" + runs-on: ${{ matrix.runs-on }} name: "Test mlir-python-bindings ${{ matrix.name }} ${{ matrix.python-version }}" diff --git a/.github/workflows/build_test_release_eudsl.yml b/.github/workflows/build_test_release_eudsl.yml index ff26c344..bb8bb808 100644 --- a/.github/workflows/build_test_release_eudsl.yml +++ b/.github/workflows/build_test_release_eudsl.yml @@ -277,8 +277,8 @@ jobs: # "macos-13", "macos-14", "windows-2022"] python-version: [ - # "3.9", "3.10", "3.11", - "3.12", "3.13" + "3.10", "3.11", "3.12", + "3.13", "3.14", "3.14t" ] include: [ {runs-on: "ubuntu-22.04", name: "ubuntu_x86_64", os: "ubuntu"}, @@ -291,6 +291,21 @@ jobs: - runs-on: macos-13 python-version: "3.13" + - runs-on: macos-14 + python-version: "3.10" + + - runs-on: macos-14 + python-version: "3.11" + + - runs-on: macos-14 + python-version: "3.12" + + - runs-on: macos-14 + python-version: "3.13" + + - runs-on: macos-14 + python-version: "3.14" + runs-on: ${{ matrix.runs-on }} name: "Test tblgen ${{ matrix.name }} ${{ matrix.python-version }}" @@ -306,7 +321,7 @@ jobs: submodules: false - name: "Install Python" - uses: actions/setup-python@v4 + uses: actions/setup-python@v6.0.0 with: python-version: "${{ matrix.python-version }}" @@ -348,8 +363,11 @@ jobs: "windows-2022" ] python-version: [ + # we only build 3.12 and up # "3.10", "3.11", - "3.12", "3.13" + "3.12", "3.13", "3.14", + # stable abi doesn't support 3.14t? + # "3.14t" ] include: [ {runs-on: "ubuntu-22.04", name: "ubuntu_x86_64", os: "ubuntu"}, @@ -362,6 +380,18 @@ jobs: - runs-on: macos-13 python-version: "3.13" + - runs-on: macos-14 + python-version: "3.10" + + - runs-on: macos-14 + python-version: "3.11" + + - runs-on: macos-14 + python-version: "3.12" + + - runs-on: macos-14 + python-version: "3.13" + runs-on: ${{ matrix.runs-on }} name: "Test llvmpy ${{ matrix.name }} ${{ matrix.python-version }}" @@ -377,7 +407,7 @@ jobs: submodules: false - name: "Install Python" - uses: actions/setup-python@v4 + uses: actions/setup-python@v6.0.0 with: python-version: "${{ matrix.python-version }}" @@ -445,7 +475,7 @@ jobs: if: (github.event_name == 'push' && github.ref_name == 'main') || github.event_name == 'workflow_dispatch' - needs: [test-eudsl-llvmpy, test-eudsl-tblgen] + needs: [release-eudsl] permissions: contents: read diff --git a/.github/workflows/clean_releases.yml b/.github/workflows/clean_releases.yml index e9312db5..ebc94b56 100644 --- a/.github/workflows/clean_releases.yml +++ b/.github/workflows/clean_releases.yml @@ -21,7 +21,7 @@ jobs: uses: actions/checkout@v2 - name: Setup Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v6.0.0 with: python-version: "3.10" diff --git a/projects/eudsl-llvmpy/src/llvm/function.py b/projects/eudsl-llvmpy/src/llvm/function.py index 85c12061..a17f94cd 100644 --- a/projects/eudsl-llvmpy/src/llvm/function.py +++ b/projects/eudsl-llvmpy/src/llvm/function.py @@ -65,16 +65,17 @@ def __init__(self, body_builder, *, return_type=None, entry_bb_name="entry"): def _is_decl(self): # magic constant found from looking at the code for an empty fn + if sys.version_info.minor == 14: + return self.body_builder.__code__.co_code == b"\x80\x00R\x00#\x00" if sys.version_info.minor == 13: return self.body_builder.__code__.co_code == b"\x95\x00g\x00" - elif sys.version_info.minor == 12: + if sys.version_info.minor == 12: return self.body_builder.__code__.co_code == b"\x97\x00y\x00" - elif sys.version_info.minor == 11: + if sys.version_info.minor == 11: return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00" - elif sys.version_info.minor in {8, 9, 10}: + if sys.version_info.minor in {8, 9, 10}: return self.body_builder.__code__.co_code == b"d\x00S\x00" - else: - raise NotImplementedError(f"{sys.version_info.minor} not supported.") + raise NotImplementedError(f"{sys.version_info.minor} not supported.") def __str__(self): return str(f"{self.__class__} {self.__dict__}") diff --git a/projects/eudsl-python-extras/mlir/extras/dialects/func.py b/projects/eudsl-python-extras/mlir/extras/dialects/func.py index 6bcbd72b..774b724b 100644 --- a/projects/eudsl-python-extras/mlir/extras/dialects/func.py +++ b/projects/eudsl-python-extras/mlir/extras/dialects/func.py @@ -114,9 +114,7 @@ def prep_func_types(sig, return_types): return_types = list(return_types) assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in return_types - ), ( - f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" - ) + ), f"all return types must be mlir types or strings or TypeVars or lambdas {return_types=}" input_types = [ p.annotation @@ -125,9 +123,7 @@ def prep_func_types(sig, return_types): ] assert all( isinstance(r, (str, Type, TypeVar)) or isalambda(r) for r in input_types - ), ( - f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" - ) + ), f"all input types must be mlir types or strings or TypeVars or lambdas {input_types=}" user_loc = get_user_code_loc() # If ir.Context is none (like for deferred func emit) if user_loc is None: @@ -200,9 +196,9 @@ def __init__( ) if self._is_decl(): - assert len(self.input_types) == len(sig.parameters), ( - f"func decl needs all input types annotated" - ) + assert len(self.input_types) == len( + sig.parameters + ), f"func decl needs all input types annotated" self.sym_visibility = "private" self.emit() @@ -210,16 +206,15 @@ def _is_decl(self): # magic constant found from looking at the code for an empty fn if sys.version_info.minor == 14: return self.body_builder.__code__.co_code == b"\x80\x00R\x00#\x00" - elif sys.version_info.minor == 13: + if sys.version_info.minor == 13: return self.body_builder.__code__.co_code == b"\x95\x00g\x00" - elif sys.version_info.minor == 12: + if sys.version_info.minor == 12: return self.body_builder.__code__.co_code == b"\x97\x00y\x00" - elif sys.version_info.minor == 11: + if sys.version_info.minor == 11: return self.body_builder.__code__.co_code == b"\x97\x00d\x00S\x00" - elif sys.version_info.minor in {8, 9, 10}: + if sys.version_info.minor in {8, 9, 10}: return self.body_builder.__code__.co_code == b"d\x00S\x00" - else: - raise NotImplementedError(f"{sys.version_info.minor} not supported.") + raise NotImplementedError(f"{sys.version_info.minor} not supported.") def __str__(self): return str(f"{self.__class__} {self.__dict__}") @@ -338,7 +333,8 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]): for i, tvar in enumerate(generics): if not isinstance(tvar, TypeVar): raise RuntimeError( - f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, you should give it a unique name." + f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, " + f"you should use a unique one for each generic function." ) type_var_default = None if sys.version_info >= (3, 12): diff --git a/projects/eudsl-python-extras/tests/dialect/test_func.py b/projects/eudsl-python-extras/tests/dialect/test_func.py index b3b4b767..261c18e3 100644 --- a/projects/eudsl-python-extras/tests/dialect/test_func.py +++ b/projects/eudsl-python-extras/tests/dialect/test_func.py @@ -267,8 +267,6 @@ def _generic_pool2d_scf(a: T.f32(), b: T.f32()): _maxpool2d_scf = _generic_pool2d_scf[arith.maximumf] - # _op = TypeVar("_op") - @func(generics=[_op]) def _generic_pool3d_scf( a: T.f32(), @@ -278,10 +276,22 @@ def _generic_pool3d_scf( with pytest.raises( RuntimeError, - match="0th generic has probably already been reified as ; if you're using a global tvar for the generic, you should give it a unique name.", + match="0th generic has probably already been reified as ; if you're using a global tvar for the generic," + " you should use a unique one for each generic function.", ): _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf] + _op1 = TypeVar("_op1") + + @func(generics=[_op1]) + def _generic_pool3d_scf( + a: T.f32(), + b: T.f32(), + ): + _op(a, b) + + _maxpool3d_scf = _generic_pool3d_scf[arith.maximumf] + def test_generics_with_canonicalizations(ctx: MLIRContext): generics = M, K, N, dtype = list(map(TypeVar, ["M", "K", "N", "dtype"]))