55from functools import partial
66from typing import Any , List , Optional , Tuple , Union
77
8-
98from .func import FuncBase
109from .. import types as T
1110from ..meta import (
2524 get_op_result_or_op_results ,
2625)
2726from ...dialects .gpu import *
27+
28+ del constant
29+ # constant needs to be below gpu import because it needs to shadow upstream's arith.constant
30+ # noinspection PyUnusedImports
2831from .arith import constant
2932from ...ir import (
3033 ArrayAttr ,
@@ -108,9 +111,9 @@ def z(cls):
108111
109112def thread_id ():
110113 return (
111- block_dim .x * block_dim .y * thread_idx .z
112- + block_dim .x * thread_idx .y
113- + thread_idx .x
114+ block_dim .x * block_dim .y * thread_idx .z
115+ + block_dim .x * thread_idx .y
116+ + thread_idx .x
114117 )
115118
116119
@@ -126,7 +129,7 @@ def set_container_module(module):
126129
127130@register_attribute_builder ("DeviceMappingArrayAttr" )
128131def get_device_mapping_array_attr (
129- mapping : List [Attribute ], context : Optional [Context ] = None
132+ mapping : List [Attribute ], context : Optional [Context ] = None
130133) -> ArrayAttr :
131134 if context is None :
132135 context = Context .current
@@ -177,7 +180,7 @@ def smem_space(int=False):
177180@_cext .register_operation (_Dialect , replace = True )
178181class GPUModuleOp (GPUModuleOp ):
179182 def __init__ (
180- self , sym_name , targets : Optional [List [Attribute ]] = None , * , loc = None , ip = None
183+ self , sym_name , targets : Optional [List [Attribute ]] = None , * , loc = None , ip = None
181184 ):
182185 if targets is None :
183186 targets = []
@@ -188,8 +191,8 @@ def __init__(
188191 sym_name = (
189192 sym_name
190193 if (
191- issubclass (type (sym_name ), Attribute )
192- or not AttrBuilder .contains ("SymbolNameAttr" )
194+ issubclass (type (sym_name ), Attribute )
195+ or not AttrBuilder .contains ("SymbolNameAttr" )
193196 )
194197 else AttrBuilder .get ("SymbolNameAttr" )(sym_name , context = _ods_context )
195198 )
@@ -227,17 +230,17 @@ def __prepare__(cls, name, bases, **kwargs):
227230
228231class GPUFuncOp (GPUFuncOp_ ):
229232 def __init__ (
230- self ,
231- sym_name ,
232- function_type ,
233- * ,
234- sym_visibility = None ,
235- arg_attrs = None ,
236- res_attrs = None ,
237- workgroup_attrib_attrs = None ,
238- private_attrib_attrs = None ,
239- loc = None ,
240- ip = None ,
233+ self ,
234+ sym_name ,
235+ function_type ,
236+ * ,
237+ sym_visibility = None ,
238+ arg_attrs = None ,
239+ res_attrs = None ,
240+ workgroup_attrib_attrs = None ,
241+ private_attrib_attrs = None ,
242+ loc = None ,
243+ ip = None ,
241244 ):
242245 super ().__init__ (
243246 function_type = function_type ,
@@ -255,17 +258,17 @@ def __init__(
255258 self .operation .attributes ["arg_attrs" ] = (
256259 arg_attrs
257260 if (
258- isinstance (arg_attrs , Attribute )
259- or not AttrBuilder .contains ("DictArrayAttr" )
261+ isinstance (arg_attrs , Attribute )
262+ or not AttrBuilder .contains ("DictArrayAttr" )
260263 )
261264 else AttrBuilder .get ("DictArrayAttr" )(arg_attrs , context = _ods_context )
262265 )
263266 if res_attrs is not None :
264267 self .operation .attributes ["res_attrs" ] = (
265268 res_attrs
266269 if (
267- isinstance (res_attrs , Attribute )
268- or not AttrBuilder .contains ("DictArrayAttr" )
270+ isinstance (res_attrs , Attribute )
271+ or not AttrBuilder .contains ("DictArrayAttr" )
269272 )
270273 else AttrBuilder .get ("DictArrayAttr" )(res_attrs , context = _ods_context )
271274 )
@@ -274,23 +277,23 @@ def __init__(
274277 self .operation .attributes ["sym_visibility" ] = (
275278 sym_visibility
276279 if (
277- issubclass (type (sym_visibility ), Attribute )
278- or not AttrBuilder .contains ("StrAttr" )
280+ issubclass (type (sym_visibility ), Attribute )
281+ or not AttrBuilder .contains ("StrAttr" )
279282 )
280283 else AttrBuilder .get ("StrAttr" )(sym_visibility , context = _ods_context )
281284 )
282285
283286
284287class LaunchOp (LaunchOp ):
285288 def __init__ (
286- self ,
287- grid_size : Tuple [Any , Any , Any ],
288- block_size : Tuple [Any , Any , Any ],
289- async_dependencies = None ,
290- dynamic_shared_memory_size : Optional [Value ] = None ,
291- * ,
292- loc = None ,
293- ip = None ,
289+ self ,
290+ grid_size : Tuple [Any , Any , Any ],
291+ block_size : Tuple [Any , Any , Any ],
292+ async_dependencies = None ,
293+ dynamic_shared_memory_size : Optional [Value ] = None ,
294+ * ,
295+ loc = None ,
296+ ip = None ,
294297 ):
295298 _ods_context = get_default_loc_context (loc )
296299 if async_dependencies is None :
@@ -309,13 +312,13 @@ def __init__(
309312
310313
311314def launch_ (
312- grid_size : Tuple [Any , Any , Any ],
313- block_size : Tuple [Any , Any , Any ],
314- async_dependencies = None ,
315- dynamic_shared_memory_size : Optional [Value ] = None ,
316- * ,
317- loc = None ,
318- ip = None ,
315+ grid_size : Tuple [Any , Any , Any ],
316+ block_size : Tuple [Any , Any , Any ],
317+ async_dependencies = None ,
318+ dynamic_shared_memory_size : Optional [Value ] = None ,
319+ * ,
320+ loc = None ,
321+ ip = None ,
319322):
320323 for size in [grid_size , block_size ]:
321324 for i , s in enumerate (size ):
@@ -337,17 +340,17 @@ def launch_(
337340
338341class LaunchFuncOp (LaunchFuncOp ):
339342 def __init__ (
340- self ,
341- kernel : List [str ],
342- grid_size : Tuple [Any , Any , Any ],
343- block_size : Tuple [Any , Any , Any ],
344- kernel_operands : List [Value ] = None ,
345- async_dependencies = None ,
346- dynamic_shared_memory_size : Optional [Value ] = None ,
347- async_object = None ,
348- * ,
349- loc = None ,
350- ip = None ,
343+ self ,
344+ kernel : List [str ],
345+ grid_size : Tuple [Any , Any , Any ],
346+ block_size : Tuple [Any , Any , Any ],
347+ kernel_operands : List [Value ] = None ,
348+ async_dependencies = None ,
349+ dynamic_shared_memory_size : Optional [Value ] = None ,
350+ async_object = None ,
351+ * ,
352+ loc = None ,
353+ ip = None ,
351354 ):
352355 _ods_context = get_default_loc_context (loc )
353356 if async_dependencies is None :
@@ -370,15 +373,15 @@ def __init__(
370373
371374class GPUFunc (FuncBase ):
372375 def __call__ (
373- self ,
374- * kernel_operands : List [Value ],
375- grid_size : Tuple [Any , Any , Any ],
376- block_size : Tuple [Any , Any , Any ],
377- async_dependencies = None ,
378- dynamic_shared_memory_size : Optional [Value ] = None ,
379- stream = None ,
380- loc = None ,
381- ip = None ,
376+ self ,
377+ * kernel_operands : List [Value ],
378+ grid_size : Tuple [Any , Any , Any ],
379+ block_size : Tuple [Any , Any , Any ],
380+ async_dependencies = None ,
381+ dynamic_shared_memory_size : Optional [Value ] = None ,
382+ stream = None ,
383+ loc = None ,
384+ ip = None ,
382385 ):
383386 for size in [grid_size , block_size ]:
384387 for i , s in enumerate (size ):
@@ -432,16 +435,16 @@ def __call__(self, *args, **kwargs):
432435
433436@make_maybe_no_args_decorator
434437def func (
435- f ,
436- * ,
437- sym_visibility = None ,
438- arg_attrs = None ,
439- res_attrs = None ,
440- func_attrs = None ,
441- emit = False ,
442- loc = None ,
443- ip = None ,
444- emit_grid = False ,
438+ f ,
439+ * ,
440+ sym_visibility = None ,
441+ arg_attrs = None ,
442+ res_attrs = None ,
443+ func_attrs = None ,
444+ emit = False ,
445+ loc = None ,
446+ ip = None ,
447+ emit_grid = False ,
445448) -> Grid :
446449 func_ = GPUFunc (
447450 body_builder = f ,
@@ -490,14 +493,14 @@ def wait(async_dependencies: Optional[List[Value]] = None, *, loc=None, ip=None)
490493
491494
492495def alloc (
493- sizes : Union [int , Value ],
494- element_type : Type = None ,
495- async_dependencies = None ,
496- dynamic_sizes = None ,
497- symbol_operands = None ,
498- host_shared = None ,
499- loc = None ,
500- ip = None ,
496+ sizes : Union [int , Value ],
497+ element_type : Type = None ,
498+ async_dependencies = None ,
499+ dynamic_sizes = None ,
500+ symbol_operands = None ,
501+ host_shared = None ,
502+ loc = None ,
503+ ip = None ,
501504):
502505 if symbol_operands is None :
503506 symbol_operands = []
0 commit comments