77from typing import Tuple , Union , List , Any
88
99from ...dialects .linalg .opdsl .lang .emitter import _is_index_type
10- from .arith import Scalar
10+ from .arith import ScalarValue
1111from ...ir import DenseElementsAttr , ShapedType , Type , Value , RankedTensorType
1212
1313S = ShapedType .get_dynamic_size ()
@@ -70,7 +70,7 @@ def dtype(self) -> Type:
7070
7171@dataclass (frozen = True )
7272class _Indexer :
73- indices : Tuple [Union [int , Scalar , slice , "Ellipsis" , None ]]
73+ indices : Tuple [Union [int , ScalarValue , slice , "Ellipsis" , None ]]
7474 newaxis_dims : Tuple [int , "Ellipsis" ]
7575 in_shape : Tuple [Union [Value , int ]]
7676
@@ -80,7 +80,7 @@ def is_constant(self):
8080 def is_full (self ):
8181 return all (
8282 isinstance (idx , slice )
83- # TODO(max): could also work for constant Scalar
83+ # TODO(max): could also work for constant ScalarValue
8484 and all ([isinstance (x , int ) for x in [idx .start , idx .stop , idx .step ]])
8585 and len (range (* idx .indices (self .in_shape [i ]))) == self .in_shape [i ]
8686 for i , idx in enumerate (self .indices )
@@ -91,7 +91,7 @@ def is_full(self):
9191 def static_offsets (self ):
9292 offsets = []
9393 for i in self .indices :
94- if isinstance (i , (int , Scalar )):
94+ if isinstance (i , (int , ScalarValue )):
9595 offsets .append (int (i ))
9696 elif isinstance (i , slice ):
9797 offsets .append (int (i .start ))
@@ -103,7 +103,7 @@ def static_offsets(self):
103103 def static_sizes (self ):
104104 sizes = []
105105 for i in self .indices :
106- if isinstance (i , (int , Scalar )):
106+ if isinstance (i , (int , ScalarValue )):
107107 sizes .append (1 )
108108 elif isinstance (i , slice ):
109109 start , stop , step = map (int , (i .start , i .stop , i .step ))
@@ -123,7 +123,7 @@ def static_sizes(self):
123123 def static_strides (self ):
124124 strides = []
125125 for i in self .indices :
126- if isinstance (i , (int , Scalar )):
126+ if isinstance (i , (int , ScalarValue )):
127127 strides .append (1 )
128128 elif isinstance (i , slice ):
129129 strides .append (int (i .step ))
@@ -133,13 +133,13 @@ def static_strides(self):
133133
134134
135135def _indices_to_indexer (
136- idx : Tuple [Union [Scalar , slice , "Ellipsis" , None ]], in_shape : Tuple [int ]
136+ idx : Tuple [Union [ScalarValue , slice , "Ellipsis" , None ]], in_shape : Tuple [int ]
137137) -> _Indexer :
138138 """Processes sequence of index objects and constructs _Indexer with
139139 corresponding indexing tensor and collapse dims (i.e., scatter/gather dims).
140140
141141 Args:
142- idx: Sequence (list or tuple) of slices, ellipses, Scalar , or Tensors.
142+ idx: Sequence (list or tuple) of slices, ellipses, ScalarValue , or Tensors.
143143 in_shape: The shape of the tensor being indexed into.
144144
145145 Returns:
@@ -150,13 +150,15 @@ def _indices_to_indexer(
150150
151151 in_axis = 0 # Current axis in input.
152152 out_axis = 0 # Current axis in output.
153- indices : List [Union [Scalar , slice , Ellipsis , None ]] = [slice (None )] * len (in_shape )
153+ indices : List [Union [ScalarValue , slice , Ellipsis , None ]] = [slice (None )] * len (
154+ in_shape
155+ )
154156 newaxis_dims : List [int ] = []
155157
156158 # nb: idx_e <-> idx_element
157159 for idx_i , idx_e in enumerate (idx ):
158160 if _is_scalar (idx_e ) and _has_index_type (idx_e ):
159- # Handle basic Scalar indexes.
161+ # Handle basic ScalarValue indexes.
160162 indices [in_axis ] = idx_e
161163 in_axis += 1
162164 # Handle newaxis (None)
@@ -219,7 +221,7 @@ def _indices_to_indexer(
219221 if _is_constant_index (idx ) and _is_constant_scalar (in_shape [i ]):
220222 if isinstance (idx , slice ):
221223 indices [i ] = slice (* idx .indices (int (in_shape [i ])))
222- elif isinstance (idx , Scalar ):
224+ elif isinstance (idx , ScalarValue ):
223225 indices [i ] = int (idx )
224226
225227 return _Indexer (
@@ -234,7 +236,7 @@ def _canonicalize_tuple_index(idx: Tuple[Any], rank: int):
234236
235237 Args:
236238 rank: Rank of tensor.
237- idx: Index object (Scalar, Tensor , slice, Ellipse, or None).
239+ idx: Index object (ScalarValue, TensorValue , slice, Ellipse, or None).
238240
239241 Returns:
240242 Tuple of index objects with no ellipses.
@@ -282,12 +284,12 @@ def _is_int_arraylike(x):
282284
283285
284286def _is_scalar (e : Any ) -> bool :
285- """Checks whether e is a Scalar or can be used to construct a Scalar .
287+ """Checks whether e is a ScalarValue or can be used to construct a ScalarValue .
286288
287289 Args:
288290 e: Anything
289291 """
290- return isinstance (e , Scalar ) or isinstance (e , (int , float , bool ))
292+ return isinstance (e , ScalarValue ) or isinstance (e , (int , float , bool ))
291293
292294
293295def _has_index_type (e : Any ) -> bool :
@@ -310,7 +312,7 @@ def _has_index_type(e: Any) -> bool:
310312
311313def _is_constant_index (e : Any ) -> bool :
312314 return (
313- (isinstance (e , Scalar ) and e .is_constant ())
315+ (isinstance (e , ScalarValue ) and e .is_constant ())
314316 or isinstance (e , (int , float , bool ))
315317 or (
316318 isinstance (e , slice )
@@ -323,7 +325,7 @@ def _is_constant_index(e: Any) -> bool:
323325
324326def _is_constant_scalar (e : Any ) -> bool :
325327 return (
326- (isinstance (e , Scalar ) and e .is_constant ())
328+ (isinstance (e , ScalarValue ) and e .is_constant ())
327329 or (isinstance (e , (int , float , bool )) and e != ShapedType .get_dynamic_size ())
328330 or e is None
329331 )
@@ -349,7 +351,7 @@ def _maybe_compute_size(start, stop, step):
349351 and start .owner .operands [0 ] == stop .owner .operands [0 ].owner .operands [0 ]
350352 and stop .owner .operands [1 ].is_constant ()
351353 and isinstance (step , int )
352- or (isinstance (step , Scalar ) and step .is_constant ())
354+ or (isinstance (step , ScalarValue ) and step .is_constant ())
353355 ):
354356 # looks like this
355357 # l = lambda l: l * D
0 commit comments