Commit 7d49a20

[MLIR][XeVM] Improve matrix ops lowering
1 parent 2024d67 commit 7d49a20

2 files changed: +42 -21 lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 19 additions & 9 deletions
@@ -609,16 +609,25 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     Value baseAddr32 = adaptor.getMemDesc();
     Value mdescVal = op.getMemDesc();
     // Load result or Store value Type can be vector or scalar.
-    Value data;
-    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
-      data = op.getResult();
-    else
-      data = adaptor.getData();
-    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
+    Type dataTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+      Type resType = op.getResult().getType();
+      // Some transforms may leave a unit dimension in the 2D vector; adaptors
+      // do not catch it for results.
+      if (auto vecType = dyn_cast<VectorType>(resType)) {
+        auto nonUnitDims = llvm::count_if(vecType.getShape(),
+                                          [](int64_t d) { return d != 1; });
+        assert(nonUnitDims <= 1 &&
+               "Expected either 1D vector or nD with unit dimensions");
+        resType = VectorType::get({vecType.getNumElements()},
+                                  vecType.getElementType());
+      }
+      dataTy = resType;
+    } else
+      dataTy = adaptor.getData().getType();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
     if (!valOrResVecTy)
-      valOrResVecTy = VectorType::get(1, data.getType());
-    if (valOrResVecTy.getShape().size() != 1)
-      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");
+      valOrResVecTy = VectorType::get(1, dataTy);
 
     int64_t elemBitWidth =
         valOrResVecTy.getElementType().getIntOrFloatBitWidth();
@@ -1109,6 +1118,7 @@ struct ConvertXeGPUToXeVMPass
     };
     typeConverter.addSourceMaterialization(
        singleElementVectorMaterializationCast);
+    typeConverter.addSourceMaterialization(vectorMaterializationCast);
     typeConverter.addTargetMaterialization(memrefMaterializationCast);
     typeConverter.addTargetMaterialization(ui32MaterializationCast);
     typeConverter.addTargetMaterialization(ui64MaterializationCast);
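
Note: the hunk above registers vectorMaterializationCast as an additional source materialization, but that callback's definition is outside this diff. As a rough illustration only, a source materialization of this kind could look like the sketch below, assuming it merely reshapes a flattened 1-D vector back to the requested n-D vector type via vector.shape_cast (the op the new test at the bottom of loadstore_matrix.mlir checks for); the guards and exact structure here are assumptions, not the file's actual code.

// Hypothetical sketch; the real vectorMaterializationCast in this file is
// not shown in the commit.
static Value vectorMaterializationCast(OpBuilder &builder, Type type,
                                       ValueRange inputs, Location loc) {
  if (inputs.size() != 1)
    return Value();
  Value input = inputs.front();
  auto srcTy = dyn_cast<VectorType>(input.getType());
  auto dstTy = dyn_cast<VectorType>(type);
  // Reshape only when element count and element type agree,
  // e.g. vector<16xf16> -> vector<1x16xf16>.
  if (srcTy && dstTy && srcTy.getNumElements() == dstTy.getNumElements() &&
      srcTy.getElementType() == dstTy.getElementType())
    return builder.create<vector::ShapeCastOp>(loc, type, input);
  return Value();
}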

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 23 additions & 12 deletions
@@ -30,7 +30,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
   gpu.func @load_store_matrix_plain_2d_input(%arg0: memref<8192xi8, 3>) -> f32 {
     %c0 = arith.constant 0 : index
     %view = memref.view %arg0[%c0][]: memref<8192xi8, 3> to memref<64x32xf32, 3>
-
+
     %subview = memref.subview %view[32, 0] [32, 32] [1, 1] : memref<64x32xf32, 3> to memref<32x32xf32, strided<[32, 1], offset: 1024>, 3>
 
     %0 = xegpu.create_mem_desc %subview : memref<32x32xf32, strided<[32, 1], offset: 1024>, 3> -> !xegpu.mem_desc<32x32xf32>
@@ -43,7 +43,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
 
     %tid_x = gpu.thread_id x
-
+
     %1 = xegpu.load_matrix %0[%c0, %tid_x]: !xegpu.mem_desc<32x32xf32>, index, index -> f32
 
     //CHECK: llvm.store {{.*}}, {{.*}} : f32, !llvm.ptr<3>
@@ -81,15 +81,15 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> f16
-
+
 
     %tid_x = gpu.thread_id x
     %c13 = arith.constant 13 : index
     %1 = xegpu.load_matrix %0[%c13, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> f16
 
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
-
-    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
+
+    xegpu.store_matrix %1, %0[%c13, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index
     gpu.return %1: f16
   }
 
@@ -102,12 +102,12 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %arg0 : memref<4096xi8, 3> -> index
     //CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
     %0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
-
+
     //CHECK: %[[tid_x:.*]] = gpu.thread_id x
     //CHECK: %[[c19:.*]] = arith.constant 19 : index
     %tid_x = gpu.thread_id x
     %c19 = arith.constant 19: index
-
+
     //CHECK: %[[c16:.*]] = arith.constant 16 : index
     //CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
     //CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -127,10 +127,10 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}} : !llvm.ptr<3> -> f16
     %1 = xegpu.load_matrix %0[%c19, %tid_x]: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> f16
-
+
     //CHECK: llvm.store %[[loaded]], {{.*}} : f16, !llvm.ptr<3>
     xegpu.store_matrix %1, %0[%c19, %tid_x]: f16, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
-
+
     //CHECK: gpu.return %[[loaded]] : f16
     gpu.return %1: f16
   }
@@ -161,7 +161,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     //CHECK: %[[add3:.*]] = arith.addi %[[mul3]], %[[add2]] : index
 
     //CHECK: %[[loaded:.*]] = llvm.load {{.*}}: !llvm.ptr<3> -> vector<8xf16>
-
+
     %tid_x = gpu.thread_id x
     %c16 = arith.constant 16 : index
     %1 = xegpu.load_matrix %0[%c16, %tid_x] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<stride = [1, 32], block = [16, 16]>>, index, index -> vector<8xf16>
@@ -172,7 +172,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     gpu.return %1: vector<8xf16>
   }
 
-
+
   // e.g. for mem_desc<32x64xf16, @block=[16, 16]>
   // its memory layout tuple is ([2,4,16,16],[1024,256,16,1])
   //CHECK-LABEL: load_store_matrix_blocked_subgroupblockio
@@ -214,11 +214,22 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
     %1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>
 
     //CHECK: %[[storeDataI16:.*]] = vector.bitcast %[[loaded]] : vector<8xf16> to vector<8xi16>
-    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
+    //CHECK: xevm.blockstore %[[ptr]], %[[storeDataI16]] : (!llvm.ptr<3>, vector<8xi16>)
 
     xegpu.store_matrix %1, %0[%c16, %c48] {subgroup_block_io}: vector<8xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index
 
     gpu.return %1: vector<8xf16>
   }
 
+  gpu.func @matrix_vector_materialization(%matrixdesc : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>) {
+    // CHECK: %[[XEVM_VECTOR:.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> vector<16xf16>
+    // CHECK: %[[SOURCE_MATERIALIZE:.*]] = vector.shape_cast %[[XEVM_VECTOR]] : vector<16xf16> to vector<1x16xf16>
+    // CHECK: %[[XEGPU_VECTOR:.*]] = arith.addf %[[SOURCE_MATERIALIZE]], %[[SOURCE_MATERIALIZE]] : vector<1x16xf16>
+    // CHECK: %[[TARGET_MATERIALIZE:.*]] = vector.shape_cast %[[XEGPU_VECTOR]] : vector<1x16xf16> to vector<16xf16>
+    // CHECK: llvm.store %[[TARGET_MATERIALIZE]], %{{.*}} : vector<16xf16>, !llvm.ptr<3>
+    %loaded = xegpu.load_matrix %matrixdesc[16,0] : !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>> -> vector<1x16xf16>
+    %loaded_2 = arith.addf %loaded, %loaded : vector<1x16xf16>
+    xegpu.store_matrix %loaded_2, %matrixdesc[16,0] : vector<1x16xf16>, !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
+    gpu.return
+  }
 }
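
For readers following the blocked-layout tests: the test comment above gives mem_desc<32x64xf16, block = [16, 16]> the layout tuple ([2,4,16,16],[1024,256,16,1]), and the divsi/remsi CHECK lines implement exactly that decomposition. A standalone sketch of the same arithmetic, with blockedOffset as a name invented here purely for illustration:

#include <cstdint>
#include <cstdio>

// Linear element offset for mem_desc<32x64xf16, block = [16, 16]>,
// layout tuple ([2, 4, 16, 16], [1024, 256, 16, 1]).
static int64_t blockedOffset(int64_t row, int64_t col) {
  // Decompose (row, col) into (block row, block col, row within block,
  // col within block), mirroring the arith.divsi / arith.remsi pairs.
  const int64_t idx[4] = {row / 16, col / 16, row % 16, col % 16};
  const int64_t strides[4] = {1024, 256, 16, 1};
  int64_t offset = 0;
  for (int i = 0; i < 4; ++i)
    offset += idx[i] * strides[i];
  return offset; // in elements; scale by sizeof(f16) = 2 for bytes
}

int main() {
  // The subgroup_block_io test indexes [16, 48]: blocks (1, 3),
  // in-block (0, 0), so 1*1024 + 3*256 = 1792.
  std::printf("%lld\n", static_cast<long long>(blockedOffset(16, 48)));
  return 0;
}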
