Source code for simvx.graphics.gpu.pipeline_compute
"""Compute pipeline factory.
Each compute pass previously inlined the same ~50-line block of
``vkCreatePipelineLayout`` + ``vkCreateComputePipelines`` boilerplate.
:func:`create_compute_pipeline` collapses that to a single call returning
``(pipeline, pipeline_layout, shader_module)`` ready for storage and
later cleanup.
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Any
import vulkan as vk
from ..materials.shader_compiler import compile_shader
from .pipeline import create_shader_module
# Module logger; create_compute_pipeline emits a debug line for each pipeline it builds.
log: logging.Logger = logging.getLogger(__name__)
# Public API of this module: only the pipeline factory is exported.
__all__ = ["create_compute_pipeline"]
[docs]
def create_compute_pipeline(
    device: Any,
    shader_path: Path,
    descriptor_layouts: list[Any],
    push_constant_size: int = 0,
    *,
    entry_point: bytes | str = b"main",
) -> tuple[Any, Any, Any]:
    """Compile a compute shader and build (pipeline, pipeline_layout, shader_module).

    Args:
        device: Vulkan logical device.
        shader_path: Path to the .comp source file.
        descriptor_layouts: Descriptor set layouts to bind to slots 0..N-1.
        push_constant_size: Bytes of push-constant range exposed to the shader
            (0 = none). Must be a non-negative multiple of 4, as Vulkan requires
            for ``VkPushConstantRange.size``.
        entry_point: Shader entry point name (defaults to ``main``). A ``str``
            is also accepted and encoded as UTF-8.

    Returns:
        ``(pipeline, pipeline_layout, shader_module)`` — caller owns all three.

    Raises:
        ValueError: If ``push_constant_size`` is negative or not a multiple of 4.
        RuntimeError: If ``vkCreatePipelineLayout`` or ``vkCreateComputePipelines``
            does not return ``VK_SUCCESS``. Handles created before the failure
            (shader module, pipeline layout) are destroyed before re-raising.
    """
    # Validate arguments before compiling, so bad input never leaves a live
    # shader module behind.
    if push_constant_size < 0 or push_constant_size % 4:
        raise ValueError(
            f"push_constant_size must be a non-negative multiple of 4, got {push_constant_size}"
        )
    if isinstance(entry_point, str):
        entry_point = entry_point.encode("utf-8")

    ffi = vk.ffi
    module = create_shader_module(device, compile_shader(shader_path))
    pipeline_layout = None
    try:
        # ffi.new zero-fills the struct, so counts/pointers left unset mean "none".
        layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
        layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO

        # cffi does not keep a pointee alive just because its address was stored
        # in a struct field; the locals below hold the arrays until the call returns.
        n_sets = len(descriptor_layouts)
        if n_sets:
            set_layouts = ffi.new(f"VkDescriptorSetLayout[{n_sets}]", descriptor_layouts)
            layout_ci.setLayoutCount = n_sets
            layout_ci.pSetLayouts = set_layouts

        if push_constant_size > 0:
            push_range = ffi.new("VkPushConstantRange*")
            push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
            push_range.offset = 0
            push_range.size = push_constant_size
            layout_ci.pushConstantRangeCount = 1
            layout_ci.pPushConstantRanges = push_range

        layout_out = ffi.new("VkPipelineLayout*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreatePipelineLayout, device, layout_ci, ffi.NULL, layout_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreatePipelineLayout (compute) failed: {result}")
        pipeline_layout = layout_out[0]

        # main_name must outlive the vkCreateComputePipelines call (stage.pName points at it).
        main_name = ffi.new("char[]", entry_point)
        stage = ffi.new("VkPipelineShaderStageCreateInfo*")
        stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
        stage.module = module
        stage.pName = main_name

        pipeline_ci = ffi.new("VkComputePipelineCreateInfo*")
        pipeline_ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
        pipeline_ci.layout = pipeline_layout
        pipeline_ci.stage = stage[0]

        pipeline_out = ffi.new("VkPipeline*")
        result = vk._vulkan._callApi(
            vk._vulkan.lib.vkCreateComputePipelines, device, ffi.NULL, 1, pipeline_ci, ffi.NULL, pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateComputePipelines failed: {result}")
    except BaseException:
        # Roll back whatever was created so a failed build does not leak handles;
        # the original exception is re-raised unchanged.
        if pipeline_layout is not None:
            vk.vkDestroyPipelineLayout(device, pipeline_layout, None)
        vk.vkDestroyShaderModule(device, module, None)
        raise

    pipeline = pipeline_out[0]
    log.debug("Compute pipeline created: %s", shader_path.name)
    return pipeline, pipeline_layout, module