Source code for simvx.graphics.gpu.pipeline_compute

"""Compute pipeline factory.

Each compute pass previously inlined the same ~50-line block of
``vkCreatePipelineLayout`` + ``vkCreateComputePipelines`` boilerplate.
:func:`create_compute_pipeline` collapses that to a single call returning
``(pipeline, pipeline_layout, shader_module)`` ready for storage and
later cleanup.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import vulkan as vk

from ..materials.shader_compiler import compile_shader
from .pipeline import create_shader_module

# Logger named after this module's dotted import path.
log = logging.getLogger(name=__name__)

# Names exported by ``from ... import *``.
__all__ = [
    "create_compute_pipeline",
]


def create_compute_pipeline(
    device: Any,
    shader_path: Path,
    descriptor_layouts: list[Any],
    push_constant_size: int = 0,
    *,
    entry_point: bytes | str = b"main",
) -> tuple[Any, Any, Any]:
    """Compile a compute shader and build (pipeline, pipeline_layout, shader_module).

    Args:
        device: Vulkan logical device.
        shader_path: Path to the .comp source file.
        descriptor_layouts: Descriptor set layouts to bind to slots 0..N-1.
        push_constant_size: Bytes of push-constant range exposed to the shader
            (0 = none). The range starts at offset 0 and is visible to the
            compute stage only.
        entry_point: Shader entry point name (defaults to ``main``). A ``str``
            is accepted and encoded as UTF-8.

    Returns:
        ``(pipeline, pipeline_layout, shader_module)`` — caller owns all three.

    Raises:
        RuntimeError: If ``vkCreatePipelineLayout`` or
            ``vkCreateComputePipelines`` does not return ``VK_SUCCESS``.
            Objects created before the failure (shader module, pipeline
            layout) are destroyed before the exception propagates.
    """
    # cffi's "char[]" initializer only takes bytes on Python 3; accept str too.
    if isinstance(entry_point, str):
        entry_point = entry_point.encode("utf-8")

    # Resolve binding internals before allocating anything, so a lookup
    # failure here cannot leak a shader module.
    ffi = vk.ffi
    lib = vk._vulkan.lib
    call_api = vk._vulkan._callApi

    module = create_shader_module(device, compile_shader(shader_path))

    pipeline_layout = None
    try:
        # --- Pipeline layout: descriptor set layouts + optional push range ---
        layout_ci = ffi.new("VkPipelineLayoutCreateInfo*")
        layout_ci.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO

        n_sets = len(descriptor_layouts)
        if n_sets:
            # Keep the cffi array in a local so it stays alive across the
            # create call; the struct field alone does not own it.
            set_layouts = ffi.new(f"VkDescriptorSetLayout[{n_sets}]", descriptor_layouts)
            layout_ci.setLayoutCount = n_sets
            layout_ci.pSetLayouts = set_layouts
        else:
            layout_ci.setLayoutCount = 0

        if push_constant_size > 0:
            push_range = ffi.new("VkPushConstantRange*")
            push_range.stageFlags = vk.VK_SHADER_STAGE_COMPUTE_BIT
            push_range.offset = 0
            push_range.size = push_constant_size
            layout_ci.pushConstantRangeCount = 1
            layout_ci.pPushConstantRanges = push_range

        layout_out = ffi.new("VkPipelineLayout*")
        result = call_api(
            lib.vkCreatePipelineLayout,
            device,
            layout_ci,
            ffi.NULL,
            layout_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreatePipelineLayout (compute) failed: {result}")
        pipeline_layout = layout_out[0]

        # --- Compute pipeline: single shader stage bound to the layout ---
        main_name = ffi.new("char[]", entry_point)  # must outlive the create call
        stage = ffi.new("VkPipelineShaderStageCreateInfo*")
        stage.sType = vk.VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO
        stage.stage = vk.VK_SHADER_STAGE_COMPUTE_BIT
        stage.module = module
        stage.pName = main_name

        pipeline_ci = ffi.new("VkComputePipelineCreateInfo*")
        pipeline_ci.sType = vk.VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO
        pipeline_ci.layout = pipeline_layout
        pipeline_ci.stage = stage[0]  # struct copy; pName still points at main_name

        pipeline_out = ffi.new("VkPipeline*")
        result = call_api(
            lib.vkCreateComputePipelines,
            device,
            ffi.NULL,  # no pipeline cache
            1,
            pipeline_ci,
            ffi.NULL,
            pipeline_out,
        )
        if result != vk.VK_SUCCESS:
            raise RuntimeError(f"vkCreateComputePipelines failed: {result}")
        pipeline = pipeline_out[0]
    except BaseException:
        # Release whatever was created so far, in reverse creation order,
        # then propagate the original error unchanged.
        if pipeline_layout is not None:
            call_api(lib.vkDestroyPipelineLayout, device, pipeline_layout, ffi.NULL)
        call_api(lib.vkDestroyShaderModule, device, module, ffi.NULL)
        raise

    log.debug("Compute pipeline created: %s", shader_path.name)
    return pipeline, pipeline_layout, module