# Source code for simvx.graphics.renderer.light_cull_pass

"""Forward+ tiled light culling via compute shader.

Dispatches a compute shader that assigns lights to 16x16 screen-space tiles
using depth-aware frustum culling.  The output buffers (light index list and
per-tile offset/count) are bound to the fragment shader so it can loop over
only the lights relevant to each pixel's tile.
"""

import logging
import math
from typing import Any

import numpy as np
import vulkan as vk

from ..gpu.descriptors import (
    DescriptorWriteBatch,
    allocate_descriptor_set,
    create_descriptor_set_layout,
    create_pool_for_types,
)
from ..gpu.memory import create_buffer, upload_numpy
from ..gpu.pipeline_compute import create_compute_pipeline
from .pass_helpers import create_nearest_sampler
from .render_pass import FrameContext, RenderPass

__all__ = ["LightCullPass"]

log = logging.getLogger(__name__)

_TILE_SIZE = 16
# CullParams SSBO layout (must match light_cull.comp binding 0):
#   uint  grid_dims_x   (4)
#   uint  grid_dims_y   (4)
#   uint  light_count   (4)
#   float near_plane    (4)
#   float far_plane     (4)
#   uint  _pad[3]       (12)
#   uint  global_index  (4)    — atomic counter
# Total = 36 bytes, padded to 48 for alignment
_CULL_PARAMS_SIZE = 48
_VK_WHOLE_SIZE_U64 = 0xFFFFFFFFFFFFFFFF

class LightCullPass(RenderPass):
    """GPU tiled light culling for Forward+ rendering.

    Creates a compute pipeline that reads a depth texture and light SSBO,
    then outputs per-tile light index lists consumed by the fragment shader.

    Typical lifecycle: ``setup()`` creates the GPU objects,
    ``update_descriptors()`` binds the depth view and light SSBO,
    ``record()`` is called per frame, and ``destroy()`` releases everything.
    """

    name = "light_cull"
    stage = "pre_render"
    inputs = ()
    outputs = ("light_tiles",)

    def __init__(self, engine: Any):
        """Store the engine reference; no GPU objects are created here.

        Args:
            engine: Owning engine. This class reads ``_device``,
                ``_physical_device`` and ``shader_dir`` from it.
        """
        super().__init__()
        self._engine = engine
        self._ready = False

        # Pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None

        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None
        self._depth_sampler: Any = None

        # Buffers
        self._cull_params_buf: Any = None
        self._cull_params_mem: Any = None
        self._light_index_buf: Any = None
        self._light_index_mem: Any = None
        self._tile_buf: Any = None
        self._tile_mem: Any = None

        # Dimensions
        self._grid_x: int = 0
        self._grid_y: int = 0
        self._max_tiles: int = 0
        self._max_light_indices: int = 0

        # Remembered so resize() can rebuild with the caller's light budget
        # instead of silently falling back to the setup() default.
        self._max_lights: int = 256

        # Actual allocated byte sizes of the output SSBOs (0 until setup()).
        self._light_index_buf_size: int = 0
        self._tile_buf_size: int = 0

    def setup(self, width: int, height: int, max_lights: int = 256) -> None:
        """Create compute pipeline, SSBOs, and descriptor set.

        If any creation step raises, the objects created so far are released
        before the exception propagates, so a failed setup does not leak.

        Args:
            width: Viewport width in pixels.
            height: Viewport height in pixels.
            max_lights: Maximum number of lights in the scene.
        """
        e = self._engine
        device = e._device
        phys = e._physical_device

        self._max_lights = max_lights
        self._grid_x = math.ceil(width / _TILE_SIZE)
        self._grid_y = math.ceil(height / _TILE_SIZE)
        self._max_tiles = self._grid_x * self._grid_y
        # Worst case: every tile has every light
        self._max_light_indices = self._max_tiles * min(max_lights, 256)

        try:
            self._create_buffers(device, phys)
            self._create_descriptors(device)
            # Compile compute shader and create pipeline (push constants: 2 mat4 = 128 bytes)
            self._pipeline, self._layout, self._module = create_compute_pipeline(
                device,
                e.shader_dir / "light_cull.comp",
                [self._desc_layout],
                128,
            )
        except Exception:
            # _ready is still False here; destroy() releases partial state anyway.
            self.destroy()
            raise

        self._ready = True
        log.debug("Light cull pass initialised (%dx%d tiles)", self._grid_x, self._grid_y)

    def _create_buffers(self, device: Any, phys: Any) -> None:
        """Allocate the CullParams, light-index and tile SSBOs."""
        # CullParams SSBO (binding 0) — includes atomic counter
        self._cull_params_buf, self._cull_params_mem = create_buffer(
            device,
            phys,
            _CULL_PARAMS_SIZE,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )

        # Light index SSBO (binding 3) — output; uint32 per entry.
        # The size is floored so an empty grid still gets a non-zero buffer.
        self._light_index_buf_size = max(self._max_light_indices * 4, 4)
        self._light_index_buf, self._light_index_mem = create_buffer(
            device,
            phys,
            self._light_index_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )

        # Tile SSBO (binding 4) — output: uvec2 per tile (8 bytes)
        self._tile_buf_size = max(self._max_tiles * 8, 8)
        self._tile_buf, self._tile_mem = create_buffer(
            device,
            phys,
            self._tile_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )

    def _create_descriptors(self, device: Any) -> None:
        """Create sampler, set layout, pool and set; write the static bindings."""
        self._depth_sampler = create_nearest_sampler(device)

        # Descriptor set layout: 5 bindings
        #   0: CullParams SSBO (compute, read+atomic)
        #   1: depth texture sampler (compute)
        #   2: LightBuffer SSBO (compute, read-only)
        #   3: LightIndexBuffer SSBO (compute, write)
        #   4: TileBuffer SSBO (compute, write)
        cs = vk.VK_SHADER_STAGE_COMPUTE_BIT
        ssbo = vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
        image_sampler = vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
        self._desc_layout = create_descriptor_set_layout(device, [
            (0, ssbo, cs, 1),
            (1, image_sampler, cs, 1),
            (2, ssbo, cs, 1),
            (3, ssbo, cs, 1),
            (4, ssbo, cs, 1),
        ])
        self._desc_pool = create_pool_for_types(device, {
            ssbo: 4,
            image_sampler: 1,
        })
        self._desc_set = allocate_descriptor_set(device, self._desc_pool, self._desc_layout)

        # Write static buffer descriptors (bindings 0, 3, 4); depth + light SSBO
        # land later via update_descriptors() once the depth view exists.
        with DescriptorWriteBatch(device) as batch:
            batch.ssbo(self._desc_set, 0, self._cull_params_buf, _CULL_PARAMS_SIZE)
            batch.ssbo(self._desc_set, 3, self._light_index_buf, self._light_index_buf_size)
            batch.ssbo(self._desc_set, 4, self._tile_buf, self._tile_buf_size)

    def update_descriptors(self, depth_view: Any, light_ssbo: Any, light_ssbo_size: int) -> None:
        """Update depth texture (binding 1) and light SSBO (binding 2).

        Must be called again after ``setup()`` or a grid-changing ``resize()``,
        because those allocate a fresh descriptor set with only bindings
        0, 3 and 4 written.

        Args:
            depth_view: Image view of the depth texture sampled by the shader.
            light_ssbo: Buffer handle holding the scene lights.
            light_ssbo_size: Size in bytes of ``light_ssbo`` to bind.
        """
        with DescriptorWriteBatch(self._engine._device) as batch:
            batch.image(self._desc_set, 1, depth_view, self._depth_sampler)
            batch.ssbo(self._desc_set, 2, light_ssbo, light_ssbo_size)

    def record(self, cmd: Any, frame: FrameContext) -> None:
        """RenderPass interface — delegates to _dispatch_impl with frame data."""
        if not self._ready:
            return
        self._dispatch_impl(
            cmd,
            frame.camera_view,
            frame.camera_proj,
            frame.light_count,
            frame.near,
            frame.far,
        )

    def _dispatch_impl(
        self,
        cmd: Any,
        view_mat: np.ndarray,
        proj_mat: np.ndarray,
        light_count: int,
        near: float,
        far: float,
    ) -> None:
        """Record the light culling compute dispatch into a command buffer.

        Must be called outside a render pass, after depth is available and
        before the main geometry pass that reads the tile data.

        Args:
            cmd: Command buffer being recorded.
            view_mat: Camera view matrix; transposed before upload.
            proj_mat: Camera projection matrix; transposed before upload.
            light_count: Number of active lights written to CullParams.
            near: Near plane value written to CullParams.
            far: Far plane value written to CullParams.
        """
        # NOTE(review): with zero lights nothing is dispatched, so the tile and
        # index buffers keep whatever the previous dispatch wrote — confirm the
        # consuming shader does not read them in that case.
        if not self._ready or light_count == 0:
            return

        device = self._engine._device

        # Upload cull parameters (includes resetting the atomic counter to 0)
        params = np.zeros(12, dtype=np.uint32)  # 48 bytes = 12 uint32
        params[0] = self._grid_x
        params[1] = self._grid_y
        params[2] = light_count
        # Floats share the same storage; write them through a float32 view.
        params.view(np.float32)[3] = near
        params.view(np.float32)[4] = far
        # params[5..7] = padding, params[8] = atomic counter = 0
        upload_numpy(device, self._cull_params_mem, params)

        # Memory barrier: ensure cull params upload is visible to compute
        host_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_HOST_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_HOST_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            0,
            1, [host_barrier],
            0, None,
            0, None,
        )

        # Bind pipeline and descriptors
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd,
            vk.VK_PIPELINE_BIND_POINT_COMPUTE,
            self._layout,
            0,
            1, [self._desc_set],
            0, None,
        )

        # Push constants: view + proj (transposed for column-major GLSL)
        ffi = vk.ffi
        view_t = np.ascontiguousarray(view_mat.T, dtype=np.float32)
        proj_t = np.ascontiguousarray(proj_mat.T, dtype=np.float32)
        pc_bytes = view_t.tobytes() + proj_t.tobytes()
        cbuf = ffi.new("char[]", pc_bytes)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd,
            self._layout,
            vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0,
            len(pc_bytes),
            cbuf,
        )

        # Dispatch: one workgroup per tile
        vk.vkCmdDispatch(cmd, self._grid_x, self._grid_y, 1)

        # Barrier: compute writes → fragment reads
        compute_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0,
            1, [compute_barrier],
            0, None,
            0, None,
        )

    @property
    def ready(self) -> bool:
        """True once ``setup()`` has completed and until ``destroy()`` runs."""
        return self._ready

    @property
    def tile_buffer(self) -> Any:
        """Tile SSBO handle for binding in fragment shader descriptors."""
        return self._tile_buf

    @property
    def tile_buffer_size(self) -> int:
        """Allocated size in bytes of the tile SSBO (0 before ``setup()``)."""
        return self._tile_buf_size

    @property
    def light_index_buffer(self) -> Any:
        """Light index SSBO handle for binding in fragment shader descriptors."""
        return self._light_index_buf

    @property
    def light_index_buffer_size(self) -> int:
        """Allocated size in bytes of the light index SSBO (0 before ``setup()``)."""
        return self._light_index_buf_size

    @property
    def grid_dims(self) -> tuple[int, int]:
        """Tile grid dimensions as ``(tiles_x, tiles_y)``."""
        return (self._grid_x, self._grid_y)

    def resize(self, width: int, height: int) -> None:
        """Handle viewport resize — recreate tile/index buffers if grid changed.

        The rebuild reuses the ``max_lights`` value from the last ``setup()``.
        When the grid does change, a new descriptor set is allocated, so the
        caller must call ``update_descriptors()`` again afterwards.

        Args:
            width: New viewport width in pixels.
            height: New viewport height in pixels.
        """
        new_gx = math.ceil(width / _TILE_SIZE)
        new_gy = math.ceil(height / _TILE_SIZE)
        if (new_gx, new_gy) == (self._grid_x, self._grid_y):
            return
        # Tear down and recreate with new dimensions.
        # NOTE(review): resources are destroyed immediately; presumably the
        # caller guarantees the GPU is no longer using them — confirm.
        self.destroy()
        self.setup(width, height, self._max_lights)

    def destroy(self) -> None:
        """Destroy all GPU resources.

        Safe to call repeatedly and after a partially failed ``setup()``:
        every handle that exists is released and then reset to ``None``.
        """
        self._ready = False

        object_destroyers = (
            ("_pipeline", vk.vkDestroyPipeline),
            ("_layout", vk.vkDestroyPipelineLayout),
            ("_module", vk.vkDestroyShaderModule),
            ("_desc_layout", vk.vkDestroyDescriptorSetLayout),
            ("_desc_pool", vk.vkDestroyDescriptorPool),
            ("_depth_sampler", vk.vkDestroySampler),
        )
        buffer_attrs = (
            ("_cull_params_buf", "_cull_params_mem"),
            ("_light_index_buf", "_light_index_mem"),
            ("_tile_buf", "_tile_mem"),
        )

        owned = [attr for attr, _ in object_destroyers]
        owned.extend(attr for pair in buffer_attrs for attr in pair)
        if not any(getattr(self, attr) for attr in owned):
            # Nothing was ever created (or it was already released).
            self._desc_set = None
            return

        device = self._engine._device

        for attr, destroy_fn in object_destroyers:
            handle = getattr(self, attr)
            if handle:
                destroy_fn(device, handle, None)
            setattr(self, attr, None)

        for buf_attr, mem_attr in buffer_attrs:
            buf = getattr(self, buf_attr)
            mem = getattr(self, mem_attr)
            if buf:
                vk.vkDestroyBuffer(device, buf, None)
            if mem:
                vk.vkFreeMemory(device, mem, None)
            setattr(self, buf_attr, None)
            setattr(self, mem_attr, None)

        # The set was allocated from the pool destroyed above; drop the stale handle.
        self._desc_set = None

        log.debug("Light cull pass resources cleaned up")