Skip to content

Commit 3a67127

Browse files
Lengjunyifinn-fletcher
authored andcommitted
Use chained module to create ShaderModule in pipeline creation
Signed-off-by: Alan Liang <[email protected]> Change-Id: I69281f8f14ae5af4e538aacb194e7601cc01ab01 (cherry picked from commit 2774ac8)
1 parent f02bff0 commit 3a67127

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

graph/graph_layer.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,10 @@ class GraphLayer : public VulkanLayerImpl {
383383
startTime = std::chrono::high_resolution_clock::now();
384384
}
385385

386-
const auto *shaderModuleCreateInfo = findType<VkDataGraphPipelineShaderModuleCreateInfoARM>(
387-
createInfo.pNext, VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SHADER_MODULE_CREATE_INFO_ARM);
388-
if (shaderModuleCreateInfo == nullptr) {
386+
const auto *dataGraphPipelineShaderModuleCreateInfo =
387+
findType<VkDataGraphPipelineShaderModuleCreateInfoARM>(
388+
createInfo.pNext, VK_STRUCTURE_TYPE_DATA_GRAPH_PIPELINE_SHADER_MODULE_CREATE_INFO_ARM);
389+
if (dataGraphPipelineShaderModuleCreateInfo == nullptr) {
389390
graphLog(Severity::Error) << "Missing shader module create info" << std::endl;
390391
return VK_ERROR_UNKNOWN;
391392
}
@@ -412,8 +413,8 @@ class GraphLayer : public VulkanLayerImpl {
412413
}
413414

414415
// Constants
415-
for (uint32_t j = 0; j < shaderModuleCreateInfo->constantCount; j++) {
416-
const auto &constant = shaderModuleCreateInfo->pConstants[j];
416+
for (uint32_t j = 0; j < dataGraphPipelineShaderModuleCreateInfo->constantCount; j++) {
417+
const auto &constant = dataGraphPipelineShaderModuleCreateInfo->pConstants[j];
417418

418419
const auto *graphPipelineConstantTensor =
419420
findType<VkTensorDescriptionARM>(constant.pNext, VK_STRUCTURE_TYPE_TENSOR_DESCRIPTION_ARM);
@@ -425,8 +426,31 @@ class GraphLayer : public VulkanLayerImpl {
425426

426427
graphPipeline->makeConstTensor(constant.id, *graphPipelineConstantTensor, constant.pConstantData);
427428
}
429+
std::shared_ptr<ShaderModule> shaderModule;
430+
if (dataGraphPipelineShaderModuleCreateInfo->module == VK_NULL_HANDLE) {
431+
auto shaderModuleCreateInfo =
432+
findType<VkShaderModuleCreateInfo>(createInfo.pNext, VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO);
433+
if (shaderModuleCreateInfo == nullptr) {
434+
graphLog(Severity::Error) << "Missing both shader handle and shader create info" << std::endl;
435+
return VK_ERROR_UNKNOWN;
436+
}
428437

429-
auto shaderModule = getHandle(deviceHandle, shaderModuleCreateInfo->module);
438+
std::vector<uint32_t> spirvSource = {shaderModuleCreateInfo->pCode,
439+
shaderModuleCreateInfo->pCode +
440+
shaderModuleCreateInfo->codeSize / sizeof(uint32_t)};
441+
auto isGraph = isGraphSpirv(spirvSource);
442+
if (!isGraph.has_value()) {
443+
graphLog(Severity::Error) << "Failed to compile spirv code." << std::endl;
444+
return VK_ERROR_UNKNOWN;
445+
} else if (isGraph.value()) {
446+
shaderModule = std::make_shared<ShaderModule>(shaderModuleCreateInfo);
447+
} else {
448+
graphLog(Severity::Error) << "spirv code does not contain graph." << std::endl;
449+
return VK_ERROR_UNKNOWN;
450+
}
451+
} else {
452+
shaderModule = getHandle(deviceHandle, dataGraphPipelineShaderModuleCreateInfo->module);
453+
}
430454

431455
if (!shaderModule) {
432456
graphLog(Severity::Error) << "Shader module not recognized by Graph layer" << std::endl;

0 commit comments

Comments
 (0)