Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
getAxisInfo(ub::PoisonOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
unsigned rank = 1;
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType()))
if (auto shape = dyn_cast<RankedTensorType>(op.getType())) {
rank = shape.getRank();
} else if (auto ptrTy = dyn_cast<PointerType>(op.getType())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a separate branch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Upstream PR: triton-lang/triton#8824. In summary: without this separate branch, the rank is computed incorrectly (1 instead of 2) for `%0 = ub.poison : !tt.ptr<tensor<128x64xf16, ...>>`, because the pointee tensor type was never inspected.

if (auto tensorType = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
rank = tensorType.getRank();
}

// Poison values are never accessed, thus assume optimistic values.
return AxisInfo(AxisInfo::DimVectorT(rank, kMaxDivisor),
Expand Down Expand Up @@ -1229,6 +1233,7 @@ void AxisInfo::initDimVectorFromHint(Attribute attr, DimVectorT *vec) {
return rhs;
if (rhs.getRank() == 0)
return lhs;
assert(lhs.getRank() == rhs.getRank() && "Mismatched ranks");
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
Expand Down
16 changes: 16 additions & 0 deletions test/TritonGPU/pipeline-assign-latencies.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
}

// -----
// Test that ub.poison producing a memdesc does not get treated like a tensor
// value in AxisInfo analysis.
// NOTE(review): this is a no-crash regression test — there are no CHECK lines
// in this section; it passes as long as the pipeline runs without asserting.
// The poison value is carried through scf.for iter_args so that the dataflow
// analysis must visit it.
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-warps" = 4 : i32} {
tt.func public @minimal_crash(%lb: i32, %ub: i32) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> {
%c1 = arith.constant 1 : i32
// Poison of a memdesc type: neither a RankedTensorType nor a tt.ptr, so the
// AxisInfo poison visitor must fall back to its default rank handling.
%poison = ub.poison : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
%normal = ttg.local_alloc : () -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
// The loop-carried value starts as poison and is replaced by the real
// allocation on every iteration; joining the two lattice values is what
// previously tripped the analysis.
%result = scf.for %i = %lb to %ub step %c1 iter_args(%current = %poison) -> !ttg.memdesc<2x2xf16, #shared, #smem, mutable> : i32 {
scf.yield %normal : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
}
tt.return %result : !ttg.memdesc<2x2xf16, #shared, #smem, mutable>
}
}

// -----
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 128], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
Expand Down
29 changes: 29 additions & 0 deletions test/TritonIntelGPU/pipeline-assign-latencies.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: triton-opt %s -split-input-file -allow-unregistered-dialect -tritongpu-assign-latencies=num-stages=3 -canonicalize | FileCheck %s

// Test that ub.poison producing a ptr<tensor> gets correct rank in AxisInfo
// analysis (rank=2 for tensor<128x64>, not rank=1).
// NOTE(review): only the function label is checked below; the test otherwise
// verifies that the pass pipeline completes without a rank-mismatch assertion.
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
// CHECK-LABEL: @test_poison_rank
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
tt.func public @test_poison_rank(%arg0: !tt.ptr<f16>, %lb: i32, %ub: i32) {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c1_i64 = arith.constant 1 : i64
%c128_i64 = arith.constant 128 : i64
%c64_i64 = arith.constant 64 : i64

// Poison of a pointer-to-tensor type: the AxisInfo visitor must look through
// the !tt.ptr to the 128x64 pointee tensor to derive rank 2.
%0 = ub.poison : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>

%1 = tt.make_tensor_ptr %arg0, [%c128_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>

// Carry the poison pointer through the loop so its lattice value is joined
// with the rank-2 info produced by tt.advance on each iteration.
%result = scf.for %i = %lb to %ub step %c1_i32
iter_args(%ptr = %0) -> !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>> : i32 {

%advanced = tt.advance %ptr, [%c0_i32, %c0_i32] : <tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>

scf.yield %advanced : !tt.ptr<tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
}

tt.return
}
}
7 changes: 6 additions & 1 deletion third_party/intel/lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,12 @@ class PoisonOpAxisInfoVisitor final : public AxisInfoVisitorImpl<ub::PoisonOp> {
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
constexpr int64_t largePowerOf2 = int64_t(1) << 32;
// Poison values are never accessed, thus assume optimistic values.
if (auto shape = dyn_cast<mlir::ShapedType>(op.getType())) {
Type type = op.getType();
if (auto ptrTy = dyn_cast<triton::PointerType>(type)) {
type = ptrTy.getPointeeType();
}

if (auto shape = dyn_cast<mlir::ShapedType>(type)) {
unsigned rank = shape.getRank();
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(rank, largePowerOf2),
Expand Down
Loading