99from typing import Optional , List , Union , TypeVar
1010
1111from ..ast .util import copy_func
12- from ..ast .py_type import PyTypeVarObject
12+ from ..ast .py_type import PyTypeVarObject , _Ptr , PyObject
1313from ..meta import op_region_builder
1414from .. import types as T
1515from ..util import get_user_code_loc , make_maybe_no_args_decorator
@@ -141,6 +141,7 @@ def prep_func_types(sig, return_types):
141141class ReifiedTypeParams :
142142 name : str
143143 val : object
144+ type : Optional [type ]
144145
145146
146147class FuncBase :
@@ -232,7 +233,7 @@ def emit(self, *call_args, decl=False, force=False) -> FuncOp:
232233 if self .generics is not None :
233234 for t in self .generics :
234235 if not isinstance (t , ReifiedTypeParams ):
235- raise RuntimeError (f"{ t = } must reified" )
236+ raise RuntimeError (f"{ t = } must be reified" )
236237 locals [t .name ] = t .val
237238 for i , v in enumerate (input_types ):
238239 if isinstance (v , TypeVar ):
@@ -308,35 +309,63 @@ def __getitem__(self, item):
308309 # this also copies the function so that the original body_builder remains "generic" (via its closure)
309310 body_builder = copy_func (self .body_builder )
310311 reified_type_params = []
311- # dumb but whatever
312+
313+ # For "generics" (i.e. typevars) which are dependent on previous generics (identified by the fact that they have vals in their own closures),
314+ # we collect all such previous generics along with the concrete vals (into already_reified_type_params) and then
315+ # evaluate the typevars in the fully-populated closure. Note, in order to get the unevaled typevar bound and default value
316+ # we access them in the PyTypeVarObject C struct itself instead of the API that python provides.
312317 already_reified_type_params = {}
318+
319+ def maybe_eval_type_data_closure_vals (unevaled_type_data : _Ptr [PyObject ]):
320+ assert type (unevaled_type_data ) == _Ptr [PyObject ]
321+ unevaled_type_data = unevaled_type_data .contents .into_object ()
322+ cvrs = inspect .getclosurevars (unevaled_type_data ).nonlocals
323+ if len (cvrs ):
324+ for k , v in cvrs .items ():
325+ if not isinstance (v , TypeVar ):
326+ continue
327+ if k not in already_reified_type_params :
328+ raise RuntimeError (
329+ f"typevar { k } not reified prior to evaluating dependent typevar { t } "
330+ )
331+ cvrs [k ] = already_reified_type_params [k ]
332+ unevaled_type_data = copy_func (unevaled_type_data , cvrs )
333+ return unevaled_type_data ()
334+
313335 generics = copy .deepcopy (self .generics )
314336 for i , t in enumerate (generics ):
337+ type_var_default = None
315338 if sys .version_info >= (3 , 12 ):
316- type_var_bound = PyTypeVarObject .from_object (t ).bound
339+ type_var = PyTypeVarObject .from_object (t )
340+ type_var_bound = type_var .bound
341+ if sys .version_info >= (3 , 13 ) and t .has_default ():
342+ type_var_default = type_var .default_value
317343 else :
318344 type_var_bound = t .__bound__
319- if type_var_bound :
345+
346+ if bool (type_var_bound ):
320347 # before 3.12 typevar was just a python class
321348 # https://github.com/python/cpython/blob/3.11/Lib/typing.py#L966
322- if sys .version_info < (3 , 12 ):
323- type_var_bound = lambda : type_var_bound
349+ if sys .version_info >= (3 , 12 ):
350+ type_var_bound = maybe_eval_type_data_closure_vals (type_var_bound )
351+ elif not bool (type_var_default ):
352+ if i >= len (item ):
353+ raise RuntimeError (f"generic { t } must have concrete val" )
354+ if isinstance (item [i ], Type ):
355+ type_var_bound = "type"
324356 else :
325- type_var_bound = type_var_bound .contents .into_object ()
326- cvrs = inspect .getclosurevars (type_var_bound ).nonlocals
327- if len (cvrs ):
328- for k , v in cvrs .items ():
329- if not isinstance (v , TypeVar ):
330- continue
331- if k not in already_reified_type_params :
332- raise RuntimeError (
333- f"typevar { k } not reified prior to evaluating dependent typevar { t } "
334- )
335- cvrs [k ] = already_reified_type_params [k ]
336- type_var_bound = copy_func (type_var_bound , cvrs )
337- r = ReifiedTypeParams (t .__name__ , type_var_bound ())
357+ type_var_bound = type (item [i ]).__name__
358+
359+ if bool (type_var_default ):
360+ type_var_default = maybe_eval_type_data_closure_vals (type_var_default )
361+ type_var_bound = type (type_var_default ).__name__
362+ val = type_var_default
338363 else :
339- r = ReifiedTypeParams (t .__name__ , item [i ])
364+ if i >= len (item ):
365+ raise RuntimeError (f"generic { t } must have concrete val" )
366+ val = item [i ]
367+
368+ r = ReifiedTypeParams (t .__name__ , val , type_var_bound )
340369
341370 reified_type_params .append (r )
342371 already_reified_type_params [r .name ] = r .val
@@ -357,6 +386,11 @@ def __getitem__(self, item):
357386 self .call_op_ctor ,
358387 return_types = self .return_types ,
359388 sym_visibility = self .sym_visibility ,
389+ sym_name = (
390+ self .func_name
391+ + "_"
392+ + "_" .join ([f"{ r .type } _{ r .val } " for r in reified_type_params ])
393+ ),
360394 arg_attrs = self .arg_attrs ,
361395 res_attrs = self .res_attrs ,
362396 func_attrs = self .func_attrs ,
0 commit comments