Skip to content

Commit 352594f

Browse files
authored
[eudsl-python-extras] name mangle generics "instantiations" (#207)
1 parent 367af39 commit 352594f

File tree

6 files changed

+351
-33
lines changed

6 files changed

+351
-33
lines changed

projects/eudsl-python-extras/mlir/extras/ast/py_type.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,12 +203,17 @@ def from_object(cls, obj) -> Self:
203203
return cls.from_address(address(obj))
204204

205205

206+
# https://github.com/python/cpython/blob/9648eed33f5fca0fcb8802fe0be8d35907bc33e3/Objects/typevarobject.c#L19-L31
206207
class PyTypeVarObject(Structure):
207208
_fields_ = _py_object_fields + [
208209
("ob_size", c_ssize_t),
209210
("name", _Ptr[PyObject]),
210-
# not sure why but this is the only thing that works but that's fine because it's the only thing we need
211211
("bound", _Ptr[PyObject]),
212+
("evaluate_bound", _Ptr[PyObject]),
213+
("constraints", _Ptr[PyObject]),
214+
("evaluate_constraints", _Ptr[PyObject]),
215+
("default_value", _Ptr[PyObject]),
216+
("evaluate_default", _Ptr[PyObject]),
212217
]
213218

214219
@classmethod

projects/eudsl-python-extras/mlir/extras/ast/util.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,15 @@
44
import ast
55
import functools
66
import inspect
7+
import io
78
import types
89
from opcode import opmap
910
from textwrap import dedent
1011
from typing import Dict
1112

1213
from bytecode import ConcreteBytecode
1314
from cloudpickle import cloudpickle
15+
from ...ir import Type
1416

1517

1618
def set_lineno(node, n=1):
@@ -117,6 +119,17 @@ def replace_closure(code, new_closure: Dict):
117119
return new_code, closure
118120

119121

122+
def unpickle_mlir_type(v):
123+
return Type.parse(v)
124+
125+
126+
class MLIRTypePickler(cloudpickle.Pickler):
127+
def reducer_override(self, obj):
128+
if isinstance(obj, Type):
129+
return unpickle_mlir_type, (str(obj),)
130+
return super().reducer_override(obj)
131+
132+
120133
# Based on http://stackoverflow.com/a/6528148/190597 (Glenn Maynard);
121134
# potentially more complete approach https://stackoverflow.com/a/56901529/9045206
122135
def copy_func(f, new_closure: Dict = None):
@@ -125,7 +138,10 @@ def copy_func(f, new_closure: Dict = None):
125138
else:
126139
# see https://github.com/cloudpipe/cloudpickle/blob/f111f7ab6d302e9b1e2a568d0e4c574895db6a6e/cloudpickle/cloudpickle.py#L813
127140
# for how this trick is accomplished (dill and pickle both fail to pickle eg generic typevars)
128-
closure = cloudpickle.loads(cloudpickle.dumps(f.__closure__))
141+
with io.BytesIO() as file:
142+
cp = MLIRTypePickler(file)
143+
cp.dump(f.__closure__)
144+
closure = cloudpickle.loads(file.getvalue())
129145
code = f.__code__
130146

131147
g = types.FunctionType(

projects/eudsl-python-extras/mlir/extras/dialects/func.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import Optional, List, Union, TypeVar
1010

1111
from ..ast.util import copy_func
12-
from ..ast.py_type import PyTypeVarObject
12+
from ..ast.py_type import PyTypeVarObject, _Ptr, PyObject
1313
from ..meta import op_region_builder
1414
from .. import types as T
1515
from ..util import get_user_code_loc, make_maybe_no_args_decorator
@@ -141,6 +141,7 @@ def prep_func_types(sig, return_types):
141141
class ReifiedTypeParams:
142142
name: str
143143
val: object
144+
type: Optional[type]
144145

145146

146147
class FuncBase:
@@ -232,7 +233,7 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
232233
if self.generics is not None:
233234
for t in self.generics:
234235
if not isinstance(t, ReifiedTypeParams):
235-
raise RuntimeError(f"{t=} must reified")
236+
raise RuntimeError(f"{t=} must be reified")
236237
locals[t.name] = t.val
237238
for i, v in enumerate(input_types):
238239
if isinstance(v, TypeVar):
@@ -308,35 +309,63 @@ def __getitem__(self, item):
308309
# this also copies the function so that the original body_builder remains "generic" (via its closure)
309310
body_builder = copy_func(self.body_builder)
310311
reified_type_params = []
311-
# dumb but whatever
312+
313+
# For "generics" (i.e. typevars) which are dependent on previous generics (identified by the fact that they have vals in their own closures),
314+
# we collect all such previous generics along with the concrete vals (into already_reified_type_params) and then
315+
# evaluate the typevars in the fully-populated closure. Note, in order to get the unevaled typevar bound and default value
316+
# we access them in the PyTypeVarObject C struct itself instead of the API that python provides.
312317
already_reified_type_params = {}
318+
319+
def maybe_eval_type_data_closure_vals(unevaled_type_data: _Ptr[PyObject]):
320+
assert type(unevaled_type_data) == _Ptr[PyObject]
321+
unevaled_type_data = unevaled_type_data.contents.into_object()
322+
cvrs = inspect.getclosurevars(unevaled_type_data).nonlocals
323+
if len(cvrs):
324+
for k, v in cvrs.items():
325+
if not isinstance(v, TypeVar):
326+
continue
327+
if k not in already_reified_type_params:
328+
raise RuntimeError(
329+
f"typevar {k} not reified prior to evaluating dependent typevar {t}"
330+
)
331+
cvrs[k] = already_reified_type_params[k]
332+
unevaled_type_data = copy_func(unevaled_type_data, cvrs)
333+
return unevaled_type_data()
334+
313335
generics = copy.deepcopy(self.generics)
314336
for i, t in enumerate(generics):
337+
type_var_default = None
315338
if sys.version_info >= (3, 12):
316-
type_var_bound = PyTypeVarObject.from_object(t).bound
339+
type_var = PyTypeVarObject.from_object(t)
340+
type_var_bound = type_var.bound
341+
if sys.version_info >= (3, 13) and t.has_default():
342+
type_var_default = type_var.default_value
317343
else:
318344
type_var_bound = t.__bound__
319-
if type_var_bound:
345+
346+
if bool(type_var_bound):
320347
# before 3.12 typevar was just a python class
321348
# https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
322-
if sys.version_info < (3, 12):
323-
type_var_bound = lambda: type_var_bound
349+
if sys.version_info >= (3, 12):
350+
type_var_bound = maybe_eval_type_data_closure_vals(type_var_bound)
351+
elif not bool(type_var_default):
352+
if i >= len(item):
353+
raise RuntimeError(f"generic {t} must have concrete val")
354+
if isinstance(item[i], Type):
355+
type_var_bound = "type"
324356
else:
325-
type_var_bound = type_var_bound.contents.into_object()
326-
cvrs = inspect.getclosurevars(type_var_bound).nonlocals
327-
if len(cvrs):
328-
for k, v in cvrs.items():
329-
if not isinstance(v, TypeVar):
330-
continue
331-
if k not in already_reified_type_params:
332-
raise RuntimeError(
333-
f"typevar {k} not reified prior to evaluating dependent typevar {t}"
334-
)
335-
cvrs[k] = already_reified_type_params[k]
336-
type_var_bound = copy_func(type_var_bound, cvrs)
337-
r = ReifiedTypeParams(t.__name__, type_var_bound())
357+
type_var_bound = type(item[i]).__name__
358+
359+
if bool(type_var_default):
360+
type_var_default = maybe_eval_type_data_closure_vals(type_var_default)
361+
type_var_bound = type(type_var_default).__name__
362+
val = type_var_default
338363
else:
339-
r = ReifiedTypeParams(t.__name__, item[i])
364+
if i >= len(item):
365+
raise RuntimeError(f"generic {t} must have concrete val")
366+
val = item[i]
367+
368+
r = ReifiedTypeParams(t.__name__, val, type_var_bound)
340369

341370
reified_type_params.append(r)
342371
already_reified_type_params[r.name] = r.val
@@ -357,6 +386,11 @@ def __getitem__(self, item):
357386
self.call_op_ctor,
358387
return_types=self.return_types,
359388
sym_visibility=self.sym_visibility,
389+
sym_name=(
390+
self.func_name
391+
+ "_"
392+
+ "_".join([f"{r.type}_{r.val}" for r in reified_type_params])
393+
),
360394
arg_attrs=self.arg_attrs,
361395
res_attrs=self.res_attrs,
362396
func_attrs=self.func_attrs,

projects/eudsl-python-extras/mlir/extras/dialects/linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,10 @@ def negf(I, O, *, loc=None, ip=None):
300300
return linalg.negf(I, loc=loc, ip=ip, outs=[O])
301301

302302

303-
def pooling_nchw_max(I, K, O, *, loc=None, ip=None):
304-
return linalg.pooling_nchw_max(I, K, loc=loc, ip=ip, outs=[O])
303+
def pooling_nchw_max(I, K, O, *, strides, dilations, loc=None, ip=None):
304+
return linalg.pooling_nchw_max(
305+
I, K, strides=strides, dilations=dilations, loc=loc, ip=ip, outs=[O]
306+
)
305307

306308

307309
def pooling_nchw_sum(I, K, O, *, loc=None, ip=None):

projects/eudsl-python-extras/mlir/extras/runtime/refbackend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def invoke(*args):
178178

179179
return invoke
180180

181+
def __getitem__(self, item):
182+
return getattr(self, item)
183+
181184

182185
# A return consumer is a trampoline to a python function that will store/capture the return from the return;
183186
# this is done because you can't return structs etc from C APIs.

0 commit comments

Comments
 (0)