Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .github/workflows/build_test_release_eudsl_python_extras.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ jobs:

- name: "Build eudsl-python-extras sdist"
run: |

SHA_SHORT="$(git rev-parse --short HEAD)"
WHEEL_VERSION="$(date +'%Y%m%d.%H%M')+$SHA_SHORT"
pushd projects/eudsl-python-extras
Expand Down Expand Up @@ -175,7 +176,14 @@ jobs:
run: python -m pip install dist/eudsl_python_extras*.tar.gz

- name: "Test eudsl-python-extras"
run: python -m pytest projects/eudsl-python-extras/tests
run: |

IGNORE=""
if [[ $(python -c "print(__import__('sys').version_info < (3, 13))") == "True" ]]; then
IGNORE="--ignore projects/eudsl-python-extras/tests/dialect/test_generics.py"
fi

python -m pytest projects/eudsl-python-extras/tests $IGNORE

release-eudsl-python-extras:

Expand Down
1 change: 1 addition & 0 deletions projects/eudsl-python-extras/mlir/extras/ast/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from bytecode import ConcreteBytecode
from cloudpickle import cloudpickle

from ...ir import Type


Expand Down
54 changes: 19 additions & 35 deletions projects/eudsl-python-extras/mlir/extras/dialects/func.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import copy
import inspect
import sys
from dataclasses import dataclass
from functools import update_wrapper
from typing import Optional, List, Union, TypeVar

from ..ast.util import copy_func, copy_object
from .. import types as T
from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
from ..ast.util import copy_func
from ..meta import op_region_builder
from .. import types as T
from ..util import get_user_code_loc, make_maybe_no_args_decorator
from ...dialects._ods_common import get_op_result_or_op_results
from ...dialects.func import *
Expand All @@ -27,7 +26,6 @@
Value,
)


_call = call


Expand Down Expand Up @@ -175,7 +173,7 @@ def __init__(
self.call_op_ctor = call_op_ctor
self.arg_attrs = arg_attrs
self.res_attrs = res_attrs
self.generics = copy_object(generics)
self.generics = generics
self.loc = loc
self.ip = ip
self._func_op = None
Expand Down Expand Up @@ -323,33 +321,24 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
continue
if k not in already_reified_type_params:
raise RuntimeError(
f"typevar {k} not reified prior to evaluating dependent typevar {tvar}"
f"typevar {k} not reified prior to evaluating dependent typevar {v}"
)
cvrs[k] = already_reified_type_params[k]
unevaled_type_data = copy_func(unevaled_type_data, cvrs)
return unevaled_type_data()

generics = copy_object(self.generics)
for i, tvar in enumerate(generics):
if not isinstance(tvar, TypeVar):
raise RuntimeError(
f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, "
f"you should use a unique one for each generic function."
)
for i, tvar in enumerate(self.generics):
if tvar.__name__ in body_builder.__globals__:
raise RuntimeError("global typevars for generics are not supported")

type_var_default = None
if sys.version_info >= (3, 12):
type_var = PyTypeVarObject.from_object(tvar)
type_var_bound = type_var.bound
if sys.version_info >= (3, 13) and tvar.has_default():
type_var_default = type_var.default_value
else:
type_var_bound = tvar.__bound__
type_var = PyTypeVarObject.from_object(tvar)
type_var_bound = type_var.bound
if sys.version_info >= (3, 13) and tvar.has_default():
type_var_default = type_var.default_value

if bool(type_var_bound):
# before 3.12 typevar was just a python class
# https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
if sys.version_info >= (3, 12):
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
elif not bool(type_var_default):
if i >= len(item):
raise RuntimeError(f"generic {tvar=} must have concrete val")
Expand All @@ -372,15 +361,13 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
reified_type_params.append(r)
already_reified_type_params[r.name] = r.val

# replace the tvar in body_builder's global context with the reified val
if tvar.__name__ in body_builder.__globals__:
body_builder.__globals__[tvar.__name__] = r.val
# replace the tvar in body_builder's closure with the reified val
# only in the closure if used in the body
if r.name in body_builder.__code__.co_freevars:
free_i = body_builder.__code__.co_freevars.index(r.name)
assert (
body_builder.__closure__[free_i].cell_contents == tvar
), "typevars don't match"
if body_builder.__closure__[free_i].cell_contents != tvar:
raise RuntimeError(
f"typevars don't match: {id(body_builder.__closure__[free_i].cell_contents)=}, {id(tvar)=}"
)
body_builder.__closure__[free_i].cell_contents = r.val

name_mangled_generics = []
Expand Down Expand Up @@ -419,12 +406,9 @@ def func(
func_attrs=None,
function_type=None,
emit=False,
generics=None,
loc=None,
ip=None,
) -> FuncBase:
if generics is None and hasattr(f, "__type_params__") and f.__type_params__:
generics = f.__type_params__
func_ = FuncBase(
body_builder=f,
func_op_ctor=FuncOp.__base__,
Expand All @@ -436,7 +420,7 @@ def func(
res_attrs=res_attrs,
func_attrs=func_attrs,
function_type=function_type,
generics=generics,
generics=getattr(f, "__type_params__", None),
loc=loc,
ip=ip,
)
Expand Down
10 changes: 5 additions & 5 deletions projects/eudsl-python-extras/mlir/extras/dialects/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from functools import partial
from typing import Any, List, Optional, Tuple, Union


from .func import FuncBase
from .. import types as T
from ..meta import (
Expand All @@ -25,6 +24,10 @@
get_op_result_or_op_results,
)
from ...dialects.gpu import *

del constant
# constant needs to be below gpu import because it needs to shadow upstream's arith.constant
# noinspection PyUnusedImports
from .arith import constant
from ...ir import (
ArrayAttr,
Expand Down Expand Up @@ -439,13 +442,10 @@ def func(
res_attrs=None,
func_attrs=None,
emit=False,
generics=None,
loc=None,
ip=None,
emit_grid=False,
) -> Grid:
if generics is None and hasattr(f, "__type_params__") and f.__type_params__:
generics = f.__type_params__
func_ = GPUFunc(
body_builder=f,
func_op_ctor=GPUFuncOp,
Expand All @@ -455,7 +455,7 @@ def func(
arg_attrs=arg_attrs,
res_attrs=res_attrs,
func_attrs=func_attrs,
generics=generics,
generics=getattr(f, "__type_params__", None),
loc=loc,
ip=ip,
)
Expand Down
Loading