Skip to content

Commit 3b94e14

Browse files
authored
spirv-val: add support for SPV_ARM_cooperative_matrix_layouts (#6408)
Signed-off-by: Kevin Petit <[email protected]>
1 parent bb65141 commit 3b94e14

File tree

2 files changed

+302
-19
lines changed

2 files changed

+302
-19
lines changed

source/val/validate_memory.cpp

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2336,23 +2336,96 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
23362336
}
23372337

23382338
bool stride_required = false;
2339+
bool layout_requires_constant_stride = false;
23392340
uint64_t layout;
23402341
if (_.EvalConstantValUint64(layout_id, &layout)) {
2342+
const bool is_arm_layout =
2343+
(layout ==
2344+
(uint64_t)spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) ||
2345+
(layout ==
2346+
(uint64_t)spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM);
2347+
2348+
if (is_arm_layout) {
2349+
if (!_.HasCapability(spv::Capability::CooperativeMatrixLayoutsARM)) {
2350+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2351+
<< "Using the RowBlockedInterleavedARM or "
2352+
"ColumnBlockedInterleavedARM MemoryLayout requires the "
2353+
"CooperativeMatrixLayoutsARM capability be declared";
2354+
}
2355+
}
2356+
23412357
stride_required =
23422358
(layout == (uint64_t)spv::CooperativeMatrixLayout::RowMajorKHR) ||
2343-
(layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR);
2359+
(layout == (uint64_t)spv::CooperativeMatrixLayout::ColumnMajorKHR) ||
2360+
is_arm_layout;
2361+
layout_requires_constant_stride = is_arm_layout;
23442362
}
23452363

23462364
const auto stride_index =
23472365
(inst->opcode() == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u;
23482366
if (inst->operands().size() > stride_index) {
23492367
const auto stride_id = inst->GetOperandAs<uint32_t>(stride_index);
2350-
const auto stride = _.FindDef(stride_id);
2351-
if (!stride || !_.IsIntScalarType(stride->type_id())) {
2368+
const auto stride_inst = _.FindDef(stride_id);
2369+
if (!stride_inst || !_.IsIntScalarType(stride_inst->type_id())) {
23522370
return _.diag(SPV_ERROR_INVALID_ID, inst)
23532371
<< "Stride operand <id> " << _.getIdName(stride_id)
23542372
<< " must be a scalar integer type.";
23552373
}
2374+
// Check SPV_ARM_cooperative_matrix_layouts constraints
2375+
if (layout_requires_constant_stride &&
2376+
!spvOpcodeIsConstant(stride_inst->opcode())) {
2377+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2378+
<< "MemoryLayout " << layout
2379+
<< " requires Stride come from a constant instruction.";
2380+
}
2381+
if (layout_requires_constant_stride) {
2382+
uint64_t stride;
2383+
if (_.EvalConstantValUint64(stride_id, &stride)) {
2384+
if ((layout ==
2385+
(uint64_t)
2386+
spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) ||
2387+
(layout ==
2388+
(uint64_t)
2389+
spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM)) {
2390+
if ((stride != 1) && (stride != 2) && (stride != 4)) {
2391+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2392+
<< "MemoryLayout " << layout
2393+
<< " requires Stride be 1, 2, or 4.";
2394+
}
2395+
}
2396+
const uint32_t elty_id = matrix_type->GetOperandAs<uint32_t>(1);
2397+
const uint32_t rows_id = matrix_type->GetOperandAs<uint32_t>(3);
2398+
const uint32_t cols_id = matrix_type->GetOperandAs<uint32_t>(4);
2399+
uint64_t rows = 0, cols = 0;
2400+
_.EvalConstantValUint64(rows_id, &rows);
2401+
_.EvalConstantValUint64(cols_id, &cols);
2402+
uint32_t sizeof_component_in_bytes = _.GetBitWidth(elty_id) / 8;
2403+
uint64_t rows_required_multiple = 4;
2404+
uint64_t cols_required_multiple = 16 / sizeof_component_in_bytes;
2405+
2406+
if (layout ==
2407+
(uint64_t)spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) {
2408+
cols_required_multiple *= stride;
2409+
}
2410+
if (layout ==
2411+
(uint64_t)
2412+
spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM) {
2413+
rows_required_multiple *= stride;
2414+
}
2415+
if ((rows != 0) && (rows % rows_required_multiple != 0)) {
2416+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2417+
<< "MemoryLayout " << layout << " with a Stride of " << stride
2418+
<< " requires that the number of rows be a multiple of "
2419+
<< rows_required_multiple;
2420+
}
2421+
if ((cols != 0) && (cols % cols_required_multiple != 0)) {
2422+
return _.diag(SPV_ERROR_INVALID_ID, inst)
2423+
<< "MemoryLayout " << layout << " with a Stride of " << stride
2424+
<< " requires that the number of columns be a multiple of "
2425+
<< cols_required_multiple;
2426+
}
2427+
}
2428+
}
23562429
} else if (stride_required) {
23572430
return _.diag(SPV_ERROR_INVALID_ID, inst)
23582431
<< "MemoryLayout " << layout << " requires a Stride.";

0 commit comments

Comments
 (0)