@@ -383,9 +383,10 @@ class GraphLayer : public VulkanLayerImpl {
383383 startTime = std::chrono::high_resolution_clock::now ();
384384 }
385385
386- const auto *shaderModuleCreateInfo = findType<VkDataGraphPipelineShaderModuleCreateInfoARM>(
387- createInfo.pNext , VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SHADER_MODULE_CREATE_INFO_ARM);
388- if (shaderModuleCreateInfo == nullptr ) {
386+ const auto *dataGraphPipelineShaderModuleCreateInfo =
387+ findType<VkDataGraphPipelineShaderModuleCreateInfoARM>(
388+ createInfo.pNext , VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SHADER_MODULE_CREATE_INFO_ARM);
389+ if (dataGraphPipelineShaderModuleCreateInfo == nullptr ) {
389390 graphLog (Severity::Error) << " Missing shader module create info" << std::endl;
390391 return VK_ERROR_UNKNOWN;
391392 }
@@ -412,8 +413,8 @@ class GraphLayer : public VulkanLayerImpl {
412413 }
413414
414415 // Constants
415- for (uint32_t j = 0 ; j < shaderModuleCreateInfo ->constantCount ; j++) {
416- const auto &constant = shaderModuleCreateInfo ->pConstants [j];
416+ for (uint32_t j = 0 ; j < dataGraphPipelineShaderModuleCreateInfo ->constantCount ; j++) {
417+ const auto &constant = dataGraphPipelineShaderModuleCreateInfo ->pConstants [j];
417418
418419 const auto *graphPipelineConstantTensor =
419420 findType<VkTensorDescriptionARM>(constant.pNext , VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM);
@@ -425,8 +426,31 @@ class GraphLayer : public VulkanLayerImpl {
425426
426427 graphPipeline->makeConstTensor (constant.id , *graphPipelineConstantTensor, constant.pConstantData );
427428 }
429+ std::shared_ptr<ShaderModule> shaderModule;
430+ if (dataGraphPipelineShaderModuleCreateInfo->module == VK_NULL_HANDLE) {
431+ auto shaderModuleCreateInfo =
432+ findType<VkShaderModuleCreateInfo>(createInfo.pNext , VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO);
433+ if (shaderModuleCreateInfo == nullptr ) {
434+ graphLog (Severity::Error) << " Missing both shader handle and shader create info" << std::endl;
435+ return VK_ERROR_UNKNOWN;
436+ }
428437
429- auto shaderModule = getHandle (deviceHandle, shaderModuleCreateInfo->module );
438+ std::vector<uint32_t > spirvSource = {shaderModuleCreateInfo->pCode ,
439+ shaderModuleCreateInfo->pCode +
440+ shaderModuleCreateInfo->codeSize / sizeof (uint32_t )};
441+ auto isGraph = isGraphSpirv (spirvSource);
442+ if (!isGraph.has_value ()) {
443+ graphLog (Severity::Error) << " Failed to compile spirv code." << std::endl;
444+ return VK_ERROR_UNKNOWN;
445+ } else if (isGraph.value ()) {
446+ shaderModule = std::make_shared<ShaderModule>(shaderModuleCreateInfo);
447+ } else {
448+ graphLog (Severity::Error) << " spirv code does not contain graph." << std::endl;
449+ return VK_ERROR_UNKNOWN;
450+ }
451+ } else {
452+ shaderModule = getHandle (deviceHandle, dataGraphPipelineShaderModuleCreateInfo->module );
453+ }
430454
431455 if (!shaderModule) {
432456 graphLog (Severity::Error) << " Shader module not recognized by Graph layer" << std::endl;
0 commit comments