@@ -2336,23 +2336,96 @@ spv_result_t ValidateCooperativeMatrixLoadStoreKHR(ValidationState_t& _,
23362336 }
23372337
23382338 bool stride_required = false ;
2339+ bool layout_requires_constant_stride = false ;
23392340 uint64_t layout;
23402341 if (_.EvalConstantValUint64 (layout_id, &layout)) {
2342+ const bool is_arm_layout =
2343+ (layout ==
2344+ (uint64_t )spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) ||
2345+ (layout ==
2346+ (uint64_t )spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM);
2347+
2348+ if (is_arm_layout) {
2349+ if (!_.HasCapability (spv::Capability::CooperativeMatrixLayoutsARM)) {
2350+ return _.diag (SPV_ERROR_INVALID_ID, inst)
2351+ << " Using the RowBlockedInterleavedARM or "
2352+ " ColumnBlockedInterleavedARM MemoryLayout requires the "
2353+ " CooperativeMatrixLayoutsARM capability be declared" ;
2354+ }
2355+ }
2356+
23412357 stride_required =
23422358 (layout == (uint64_t )spv::CooperativeMatrixLayout::RowMajorKHR) ||
2343- (layout == (uint64_t )spv::CooperativeMatrixLayout::ColumnMajorKHR);
2359+ (layout == (uint64_t )spv::CooperativeMatrixLayout::ColumnMajorKHR) ||
2360+ is_arm_layout;
2361+ layout_requires_constant_stride = is_arm_layout;
23442362 }
23452363
23462364 const auto stride_index =
23472365 (inst->opcode () == spv::Op::OpCooperativeMatrixLoadKHR) ? 4u : 3u ;
23482366 if (inst->operands ().size () > stride_index) {
23492367 const auto stride_id = inst->GetOperandAs <uint32_t >(stride_index);
2350- const auto stride = _.FindDef (stride_id);
2351- if (!stride || !_.IsIntScalarType (stride ->type_id ())) {
2368+ const auto stride_inst = _.FindDef (stride_id);
2369+ if (!stride_inst || !_.IsIntScalarType (stride_inst ->type_id ())) {
23522370 return _.diag (SPV_ERROR_INVALID_ID, inst)
23532371 << " Stride operand <id> " << _.getIdName (stride_id)
23542372 << " must be a scalar integer type." ;
23552373 }
2374+ // Check SPV_ARM_cooperative_matrix_layouts constraints
2375+ if (layout_requires_constant_stride &&
2376+ !spvOpcodeIsConstant (stride_inst->opcode ())) {
2377+ return _.diag (SPV_ERROR_INVALID_ID, inst)
2378+ << " MemoryLayout " << layout
2379+ << " requires Stride come from a constant instruction." ;
2380+ }
2381+ if (layout_requires_constant_stride) {
2382+ uint64_t stride;
2383+ if (_.EvalConstantValUint64 (stride_id, &stride)) {
2384+ if ((layout ==
2385+ (uint64_t )
2386+ spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) ||
2387+ (layout ==
2388+ (uint64_t )
2389+ spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM)) {
2390+ if ((stride != 1 ) && (stride != 2 ) && (stride != 4 )) {
2391+ return _.diag (SPV_ERROR_INVALID_ID, inst)
2392+ << " MemoryLayout " << layout
2393+ << " requires Stride be 1, 2, or 4." ;
2394+ }
2395+ }
2396+ const uint32_t elty_id = matrix_type->GetOperandAs <uint32_t >(1 );
2397+ const uint32_t rows_id = matrix_type->GetOperandAs <uint32_t >(3 );
2398+ const uint32_t cols_id = matrix_type->GetOperandAs <uint32_t >(4 );
2399+ uint64_t rows = 0 , cols = 0 ;
2400+ _.EvalConstantValUint64 (rows_id, &rows);
2401+ _.EvalConstantValUint64 (cols_id, &cols);
2402+ uint32_t sizeof_component_in_bytes = _.GetBitWidth (elty_id) / 8 ;
2403+ uint64_t rows_required_multiple = 4 ;
2404+ uint64_t cols_required_multiple = 16 / sizeof_component_in_bytes;
2405+
2406+ if (layout ==
2407+ (uint64_t )spv::CooperativeMatrixLayout::RowBlockedInterleavedARM) {
2408+ cols_required_multiple *= stride;
2409+ }
2410+ if (layout ==
2411+ (uint64_t )
2412+ spv::CooperativeMatrixLayout::ColumnBlockedInterleavedARM) {
2413+ rows_required_multiple *= stride;
2414+ }
2415+ if ((rows != 0 ) && (rows % rows_required_multiple != 0 )) {
2416+ return _.diag (SPV_ERROR_INVALID_ID, inst)
2417+ << " MemoryLayout " << layout << " with a Stride of " << stride
2418+ << " requires that the number of rows be a multiple of "
2419+ << rows_required_multiple;
2420+ }
2421+ if ((cols != 0 ) && (cols % cols_required_multiple != 0 )) {
2422+ return _.diag (SPV_ERROR_INVALID_ID, inst)
2423+ << " MemoryLayout " << layout << " with a Stride of " << stride
2424+ << " requires that the number of columns be a multiple of "
2425+ << cols_required_multiple;
2426+ }
2427+ }
2428+ }
23562429 } else if (stride_required) {
23572430 return _.diag (SPV_ERROR_INVALID_ID, inst)
23582431 << " MemoryLayout " << layout << " requires a Stride." ;
0 commit comments