Skip to content

Commit 77bd12b

Browse files
authored
[eudsl-python-extras] sundry fixes (#210)
* [eudsl-python-extras] sundry fixes
1 parent 352594f commit 77bd12b

File tree

12 files changed

+400
-186
lines changed

12 files changed

+400
-186
lines changed

projects/eudsl-python-extras/mlir/extras/dialects/_shaped_value.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Tuple, Union, List, Any
88

99
from ...dialects.linalg.opdsl.lang.emitter import _is_index_type
10-
from .arith import Scalar
10+
from .arith import ScalarValue
1111
from ...ir import DenseElementsAttr, ShapedType, Type, Value, RankedTensorType
1212

1313
S = ShapedType.get_dynamic_size()
@@ -70,7 +70,7 @@ def dtype(self) -> Type:
7070

7171
@dataclass(frozen=True)
7272
class _Indexer:
73-
indices: Tuple[Union[int, Scalar, slice, "Ellipsis", None]]
73+
indices: Tuple[Union[int, ScalarValue, slice, "Ellipsis", None]]
7474
newaxis_dims: Tuple[int, "Ellipsis"]
7575
in_shape: Tuple[Union[Value, int]]
7676

@@ -80,7 +80,7 @@ def is_constant(self):
8080
def is_full(self):
8181
return all(
8282
isinstance(idx, slice)
83-
# TODO(max): could also work for constant Scalar
83+
# TODO(max): could also work for constant ScalarValue
8484
and all([isinstance(x, int) for x in [idx.start, idx.stop, idx.step]])
8585
and len(range(*idx.indices(self.in_shape[i]))) == self.in_shape[i]
8686
for i, idx in enumerate(self.indices)
@@ -91,7 +91,7 @@ def is_full(self):
9191
def static_offsets(self):
9292
offsets = []
9393
for i in self.indices:
94-
if isinstance(i, (int, Scalar)):
94+
if isinstance(i, (int, ScalarValue)):
9595
offsets.append(int(i))
9696
elif isinstance(i, slice):
9797
offsets.append(int(i.start))
@@ -103,7 +103,7 @@ def static_offsets(self):
103103
def static_sizes(self):
104104
sizes = []
105105
for i in self.indices:
106-
if isinstance(i, (int, Scalar)):
106+
if isinstance(i, (int, ScalarValue)):
107107
sizes.append(1)
108108
elif isinstance(i, slice):
109109
start, stop, step = map(int, (i.start, i.stop, i.step))
@@ -123,7 +123,7 @@ def static_sizes(self):
123123
def static_strides(self):
124124
strides = []
125125
for i in self.indices:
126-
if isinstance(i, (int, Scalar)):
126+
if isinstance(i, (int, ScalarValue)):
127127
strides.append(1)
128128
elif isinstance(i, slice):
129129
strides.append(int(i.step))
@@ -133,13 +133,13 @@ def static_strides(self):
133133

134134

135135
def _indices_to_indexer(
136-
idx: Tuple[Union[Scalar, slice, "Ellipsis", None]], in_shape: Tuple[int]
136+
idx: Tuple[Union[ScalarValue, slice, "Ellipsis", None]], in_shape: Tuple[int]
137137
) -> _Indexer:
138138
"""Processes sequence of index objects and constructs _Indexer with
139139
corresponding indexing tensor and collapse dims (i.e., scatter/gather dims).
140140
141141
Args:
142-
idx: Sequence (list or tuple) of slices, ellipses, Scalar, or Tensors.
142+
idx: Sequence (list or tuple) of slices, ellipses, ScalarValue, or Tensors.
143143
in_shape: The shape of the tensor being indexed into.
144144
145145
Returns:
@@ -150,13 +150,15 @@ def _indices_to_indexer(
150150

151151
in_axis = 0 # Current axis in input.
152152
out_axis = 0 # Current axis in output.
153-
indices: List[Union[Scalar, slice, Ellipsis, None]] = [slice(None)] * len(in_shape)
153+
indices: List[Union[ScalarValue, slice, Ellipsis, None]] = [slice(None)] * len(
154+
in_shape
155+
)
154156
newaxis_dims: List[int] = []
155157

156158
# nb: idx_e <-> idx_element
157159
for idx_i, idx_e in enumerate(idx):
158160
if _is_scalar(idx_e) and _has_index_type(idx_e):
159-
# Handle basic Scalar indexes.
161+
# Handle basic ScalarValue indexes.
160162
indices[in_axis] = idx_e
161163
in_axis += 1
162164
# Handle newaxis (None)
@@ -219,7 +221,7 @@ def _indices_to_indexer(
219221
if _is_constant_index(idx) and _is_constant_scalar(in_shape[i]):
220222
if isinstance(idx, slice):
221223
indices[i] = slice(*idx.indices(int(in_shape[i])))
222-
elif isinstance(idx, Scalar):
224+
elif isinstance(idx, ScalarValue):
223225
indices[i] = int(idx)
224226

225227
return _Indexer(
@@ -234,7 +236,7 @@ def _canonicalize_tuple_index(idx: Tuple[Any], rank: int):
234236
235237
Args:
236238
rank: Rank of tensor.
237-
idx: Index object (Scalar, Tensor, slice, Ellipse, or None).
239+
idx: Index object (ScalarValue, TensorValue, slice, Ellipse, or None).
238240
239241
Returns:
240242
Tuple of index objects with no ellipses.
@@ -282,12 +284,12 @@ def _is_int_arraylike(x):
282284

283285

284286
def _is_scalar(e: Any) -> bool:
285-
"""Checks whether e is a Scalar or can be used to construct a Scalar.
287+
"""Checks whether e is a ScalarValue or can be used to construct a ScalarValue.
286288
287289
Args:
288290
e: Anything
289291
"""
290-
return isinstance(e, Scalar) or isinstance(e, (int, float, bool))
292+
return isinstance(e, ScalarValue) or isinstance(e, (int, float, bool))
291293

292294

293295
def _has_index_type(e: Any) -> bool:
@@ -310,7 +312,7 @@ def _has_index_type(e: Any) -> bool:
310312

311313
def _is_constant_index(e: Any) -> bool:
312314
return (
313-
(isinstance(e, Scalar) and e.is_constant())
315+
(isinstance(e, ScalarValue) and e.is_constant())
314316
or isinstance(e, (int, float, bool))
315317
or (
316318
isinstance(e, slice)
@@ -323,7 +325,7 @@ def _is_constant_index(e: Any) -> bool:
323325

324326
def _is_constant_scalar(e: Any) -> bool:
325327
return (
326-
(isinstance(e, Scalar) and e.is_constant())
328+
(isinstance(e, ScalarValue) and e.is_constant())
327329
or (isinstance(e, (int, float, bool)) and e != ShapedType.get_dynamic_size())
328330
or e is None
329331
)
@@ -349,7 +351,7 @@ def _maybe_compute_size(start, stop, step):
349351
and start.owner.operands[0] == stop.owner.operands[0].owner.operands[0]
350352
and stop.owner.operands[1].is_constant()
351353
and isinstance(step, int)
352-
or (isinstance(step, Scalar) and step.is_constant())
354+
or (isinstance(step, ScalarValue) and step.is_constant())
353355
):
354356
# looks like this
355357
# l = lambda l: l * D

projects/eudsl-python-extras/mlir/extras/dialects/arith.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -147,9 +147,9 @@ class ArithValueMeta(type(Value)):
147147
all three of the following wrappers are equivalent:
148148
149149
```
150-
s1 = Scalar(arith.ConstantOp(f64, 0.0).result)
151-
s2 = Scalar(arith.ConstantOp(f64, 0.0))
152-
s3 = Scalar(0.0)
150+
s1 = ScalarValue(arith.ConstantOp(f64, 0.0).result)
151+
s2 = ScalarValue(arith.ConstantOp(f64, 0.0))
152+
s3 = ScalarValue(0.0)
153153
```
154154
155155
In general the Python object protocol for an object instance is determined
@@ -298,8 +298,8 @@ def _binary_op(
298298
"""Generic for handling infix binary operator dispatch.
299299
300300
Args:
301-
lhs: E.g. Scalar or Tensor below.
302-
rhs: Scalar or Tensor with type matching self.
301+
lhs: E.g. ScalarValue or TensorValue below.
302+
rhs: ScalarValue or TensorValue with type matching self.
303303
op: Binary operator, currently only add, sub, mul
304304
supported.
305305
@@ -496,7 +496,7 @@ def _ne(self, other):
496496
return Value(self) != Value(other)
497497

498498

499-
class Scalar(ArithValue):
499+
class ScalarValue(ArithValue):
500500
"""Value subclass ScalarValue that adds convenience methods
501501
for getting dtype and (possibly) the stored literal value.
502502
@@ -512,7 +512,7 @@ def dtype(self) -> Type:
512512
@cached_property
513513
def literal_value(self) -> Union[int, float, bool]:
514514
if not self.is_constant():
515-
raise ValueError("Can't build literal from non-constant Scalar")
515+
raise ValueError("Can't build literal from non-constant ScalarValue")
516516
return self.owner.opview.literal_value
517517

518518
def __int__(self):
@@ -521,10 +521,10 @@ def __int__(self):
521521
def __float__(self):
522522
return float(self.literal_value)
523523

524-
def coerce(self, other) -> Tuple["Scalar", "Scalar"]:
524+
def coerce(self, other) -> Tuple["ScalarValue", "ScalarValue"]:
525525
if isinstance(other, (int, float, bool)):
526-
other = Scalar(other, dtype=self.dtype)
527-
elif isinstance(other, Scalar) and (
526+
other = ScalarValue(other, dtype=self.dtype)
527+
elif isinstance(other, ScalarValue) and (
528528
_is_index_type(self.type) or _is_index_type(other.type)
529529
):
530530
other = index_cast(other, to=self.type)
@@ -534,7 +534,7 @@ def coerce(self, other) -> Tuple["Scalar", "Scalar"]:
534534

535535

536536
for t in [BF16Type, F16Type, F32Type, F64Type, IndexType, IntegerType, ComplexType]:
537-
register_value_caster(t.static_typeid)(Scalar)
537+
register_value_caster(t.static_typeid)(ScalarValue)
538538

539539

540540
class CanonicalizeFMA(StrictTransformer):

0 commit comments

Comments
 (0)