Skip to content

Commit 6fc5522

Browse files
committed
Fix AxisInfo rank mismatch for poison tensor pointers
Signed-off-by: Witold Dziurdz <[email protected]>
1 parent 2beb243 commit 6fc5522

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
276276
getAxisInfo(ub::PoisonOp op,
277277
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
278278
unsigned rank = 1;
279-
if (auto shape = dyn_cast<RankedTensorType>(op.getType()))
279+
if (auto shape = dyn_cast<RankedTensorType>(op.getType())) {
280280
rank = shape.getRank();
281+
} else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
282+
if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
283+
rank = tensorType.getRank();
284+
}
281285

282286
// Poison values are never accessed, thus assume optimistic values.
283287
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),

0 commit comments

Comments
 (0)