Skip to content

Commit bd00241

Browse files
authored
[eudsl-python-extras] drop support for generics for py < 3.12 (#236)
After much consternation and flaggelation, I've decided we can't/shouldn't support generics for python < 3.12. The context/reason is that I thought I had robust support for just plain `TypeVar`s as globals (and in the closure) but recently I've realized that that support is not robust. In the weeds: the "generic" kernel instantiation effectively mutates the `TypeVar` and thus when the `TypeVar` is a global, multiple instantiations of the same kernel (with different concrete params) will race (on that global `TypeVar`). `TypeVar`s in the closure don't suffer from this because we copy the `body_builder` and its closure when instantiating the kernel. In 3.12 (when using the brackets `def fun[M, N, ...]` syntax) the `TypeVar`s are always in the closure[^1]. I tried to figure out somehow to move globals into closures but failed and decided it's not worth the effort anyway. Note, the added test (`test_generics`) is gated behind 3.13 because generics default value assignment is a 3.13 feature and the 3.12 parser will choke on it, but everything else works in 3.12. [^1]: That's actually a CPython implementation detail - there's a compiler transform which rewrites the function into a closed over function with those TypeVars in its closure.
1 parent 82e1725 commit bd00241

File tree

8 files changed

+620
-894
lines changed

8 files changed

+620
-894
lines changed

.github/workflows/build_test_release_eudsl_python_extras.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ jobs:
6464

6565
- name: "Build eudsl-python-extras sdist"
6666
run: |
67+
6768
SHA_SHORT="$(git rev-parse --short HEAD)"
6869
WHEEL_VERSION="$(date +'%Y%m%d.%H%M')+$SHA_SHORT"
6970
pushd projects/eudsl-python-extras
@@ -175,7 +176,14 @@ jobs:
175176
run: python -m pip install dist/eudsl_python_extras*.tar.gz
176177

177178
- name: "Test eudsl-python-extras"
178-
run: python -m pytest projects/eudsl-python-extras/tests
179+
run: |
180+
181+
IGNORE=""
182+
if [[ $(python -c "print(__import__('sys').version_info < (3, 13))") == "True" ]]; then
183+
IGNORE="--ignore projects/eudsl-python-extras/tests/dialect/test_generics.py"
184+
fi
185+
186+
python -m pytest projects/eudsl-python-extras/tests $IGNORE
179187
180188
release-eudsl-python-extras:
181189

projects/eudsl-python-extras/mlir/extras/ast/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from bytecode import ConcreteBytecode
1414
from cloudpickle import cloudpickle
15+
1516
from ...ir import Type
1617

1718

projects/eudsl-python-extras/mlir/extras/dialects/func.py

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
22
# See https://llvm.org/LICENSE.txt for license information.
33
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4-
import copy
54
import inspect
65
import sys
76
from dataclasses import dataclass
87
from functools import update_wrapper
98
from typing import Optional, List, Union, TypeVar
109

11-
from ..ast.util import copy_func, copy_object
10+
from .. import types as T
1211
from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
12+
from ..ast.util import copy_func
1313
from ..meta import op_region_builder
14-
from .. import types as T
1514
from ..util import get_user_code_loc, make_maybe_no_args_decorator
1615
from ...dialects._ods_common import get_op_result_or_op_results
1716
from ...dialects.func import *
@@ -27,7 +26,6 @@
2726
Value,
2827
)
2928

30-
3129
_call = call
3230

3331

@@ -175,7 +173,7 @@ def __init__(
175173
self.call_op_ctor = call_op_ctor
176174
self.arg_attrs = arg_attrs
177175
self.res_attrs = res_attrs
178-
self.generics = copy_object(generics)
176+
self.generics = generics
179177
self.loc = loc
180178
self.ip = ip
181179
self._func_op = None
@@ -323,33 +321,24 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
323321
continue
324322
if k not in already_reified_type_params:
325323
raise RuntimeError(
326-
f"typevar {k} not reified prior to evaluating dependent typevar {tvar}"
324+
f"typevar {k} not reified prior to evaluating dependent typevar {v}"
327325
)
328326
cvrs[k] = already_reified_type_params[k]
329327
unevaled_type_data = copy_func(unevaled_type_data, cvrs)
330328
return unevaled_type_data()
331329

332-
generics = copy_object(self.generics)
333-
for i, tvar in enumerate(generics):
334-
if not isinstance(tvar, TypeVar):
335-
raise RuntimeError(
336-
f"{i}th generic has probably already been reified as {tvar}; if you're using a global tvar for the generic, "
337-
f"you should use a unique one for each generic function."
338-
)
330+
for i, tvar in enumerate(self.generics):
331+
if tvar.__name__ in body_builder.__globals__:
332+
raise RuntimeError("global typevars for generics are not supported")
333+
339334
type_var_default = None
340-
if sys.version_info >= (3, 12):
341-
type_var = PyTypeVarObject.from_object(tvar)
342-
type_var_bound = type_var.bound
343-
if sys.version_info >= (3, 13) and tvar.has_default():
344-
type_var_default = type_var.default_value
345-
else:
346-
type_var_bound = tvar.__bound__
335+
type_var = PyTypeVarObject.from_object(tvar)
336+
type_var_bound = type_var.bound
337+
if sys.version_info >= (3, 13) and tvar.has_default():
338+
type_var_default = type_var.default_value
347339

348340
if bool(type_var_bound):
349-
# before 3.12 typevar was just a python class
350-
# https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
351-
if sys.version_info >= (3, 12):
352-
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
341+
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
353342
elif not bool(type_var_default):
354343
if i >= len(item):
355344
raise RuntimeError(f"generic {tvar=} must have concrete val")
@@ -372,15 +361,13 @@ def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
372361
reified_type_params.append(r)
373362
already_reified_type_params[r.name] = r.val
374363

375-
# replace the tvar in body_builder's global context with the reified val
376-
if tvar.__name__ in body_builder.__globals__:
377-
body_builder.__globals__[tvar.__name__] = r.val
378-
# replace the tvar in body_builder's closure with the reified val
364+
# only in the closure if used in the body
379365
if r.name in body_builder.__code__.co_freevars:
380366
free_i = body_builder.__code__.co_freevars.index(r.name)
381-
assert (
382-
body_builder.__closure__[free_i].cell_contents == tvar
383-
), "typevars don't match"
367+
if body_builder.__closure__[free_i].cell_contents != tvar:
368+
raise RuntimeError(
369+
f"typevars don't match: {id(body_builder.__closure__[free_i].cell_contents)=}, {id(tvar)=}"
370+
)
384371
body_builder.__closure__[free_i].cell_contents = r.val
385372

386373
name_mangled_generics = []
@@ -419,12 +406,9 @@ def func(
419406
func_attrs=None,
420407
function_type=None,
421408
emit=False,
422-
generics=None,
423409
loc=None,
424410
ip=None,
425411
) -> FuncBase:
426-
if generics is None and hasattr(f, "__type_params__") and f.__type_params__:
427-
generics = f.__type_params__
428412
func_ = FuncBase(
429413
body_builder=f,
430414
func_op_ctor=FuncOp.__base__,
@@ -436,7 +420,7 @@ def func(
436420
res_attrs=res_attrs,
437421
func_attrs=func_attrs,
438422
function_type=function_type,
439-
generics=generics,
423+
generics=getattr(f, "__type_params__", None),
440424
loc=loc,
441425
ip=ip,
442426
)

projects/eudsl-python-extras/mlir/extras/dialects/gpu.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import partial
66
from typing import Any, List, Optional, Tuple, Union
77

8-
98
from .func import FuncBase
109
from .. import types as T
1110
from ..meta import (
@@ -25,6 +24,10 @@
2524
get_op_result_or_op_results,
2625
)
2726
from ...dialects.gpu import *
27+
28+
del constant
29+
# constant needs to be below gpu import because it needs to shadow upstream's arith.constant
30+
# noinspection PyUnusedImports
2831
from .arith import constant
2932
from ...ir import (
3033
ArrayAttr,
@@ -439,13 +442,10 @@ def func(
439442
res_attrs=None,
440443
func_attrs=None,
441444
emit=False,
442-
generics=None,
443445
loc=None,
444446
ip=None,
445447
emit_grid=False,
446448
) -> Grid:
447-
if generics is None and hasattr(f, "__type_params__") and f.__type_params__:
448-
generics = f.__type_params__
449449
func_ = GPUFunc(
450450
body_builder=f,
451451
func_op_ctor=GPUFuncOp,
@@ -455,7 +455,7 @@ def func(
455455
arg_attrs=arg_attrs,
456456
res_attrs=res_attrs,
457457
func_attrs=func_attrs,
458-
generics=generics,
458+
generics=getattr(f, "__type_params__", None),
459459
loc=loc,
460460
ip=ip,
461461
)

0 commit comments

Comments
 (0)