"""Forward+ tiled light culling via compute shader.
Dispatches a compute shader that assigns lights to 16x16 screen-space tiles
using depth-aware frustum culling. The output buffers (light index list and
per-tile offset/count) are bound to the fragment shader so it can loop over
only the lights relevant to each pixel's tile.
"""
import logging
import math
from typing import Any
import numpy as np
import vulkan as vk
from ..gpu.descriptors import (
DescriptorWriteBatch,
allocate_descriptor_set,
create_descriptor_set_layout,
create_pool_for_types,
)
from ..gpu.memory import create_buffer, upload_numpy
from ..gpu.pipeline_compute import create_compute_pipeline
from .pass_helpers import create_nearest_sampler
from .render_pass import FrameContext, RenderPass
# Public API of this module: only the render pass class is exported.
__all__ = ["LightCullPass"]
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
# Edge length, in pixels, of one square screen-space culling tile.
_TILE_SIZE = 16

# CullParams SSBO layout (must match light_cull.comp binding 0):
#   offset  0: uint  grid_dims_x   (4)
#   offset  4: uint  grid_dims_y   (4)
#   offset  8: uint  light_count   (4)
#   offset 12: float near_plane    (4)
#   offset 16: float far_plane     (4)
#   offset 20: uint  _pad[3]       (12)
#   offset 32: uint  global_index  (4) — atomic counter
# Total = 36 bytes, padded to 48 for alignment
_CULL_PARAMS_SIZE = 48

# All-ones 64-bit value. NOTE(review): by its name this is the Vulkan
# VK_WHOLE_SIZE sentinel; it is not referenced in the visible part of this
# module — confirm whether it is still needed.
_VK_WHOLE_SIZE_U64 = 0xFFFFFFFFFFFFFFFF
class LightCullPass(RenderPass):
    """GPU tiled light culling for Forward+ rendering.

    Creates a compute pipeline that reads a depth texture and light SSBO,
    then outputs per-tile light index lists consumed by the fragment shader.

    Lifecycle:
        1. ``setup()`` creates the buffers, descriptor set and pipeline and
           writes descriptor bindings 0, 3 and 4.
        2. ``update_descriptors()`` writes bindings 1 (depth) and 2 (lights).
        3. ``record()`` emits the dispatch and the surrounding barriers.
        4. ``destroy()`` releases every Vulkan object owned by the pass.

    ``resize()`` is ``destroy()`` followed by ``setup()``. When it rebuilds
    the pass, the descriptor set and both output buffers are new objects, so
    the caller has to call ``update_descriptors()`` again and re-fetch
    ``tile_buffer`` / ``light_index_buffer`` for its own descriptor sets.
    """

    name = "light_cull"
    stage = "pre_render"
    inputs = ()
    outputs = ("light_tiles",)

    # Upper bound applied to ``max_lights`` when sizing the light index list.
    # NOTE(review): presumably mirrors a per-tile limit in light_cull.comp —
    # confirm against the shader.
    _LIGHTS_PER_TILE_CAP = 256

    def __init__(self, engine: Any):
        """Store the owning engine; no GPU work happens until ``setup()``.

        Args:
            engine: Engine object. This class reads ``_device``,
                ``_physical_device`` and ``shader_dir`` from it.
        """
        super().__init__()
        self._engine = engine
        self._ready = False
        # Pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None
        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None
        self._depth_sampler: Any = None
        # Buffers
        self._cull_params_buf: Any = None
        self._cull_params_mem: Any = None
        self._light_index_buf: Any = None
        self._light_index_mem: Any = None
        self._tile_buf: Any = None
        self._tile_mem: Any = None
        # Dimensions
        self._grid_x: int = 0
        self._grid_y: int = 0
        self._max_tiles: int = 0
        self._max_light_indices: int = 0
        # Byte sizes actually allocated for the two output buffers
        self._tile_buf_size: int = 0
        self._light_index_buf_size: int = 0
        # Last max_lights given to setup(); reused by resize()
        self._max_lights: int = 256

    @staticmethod
    def _tiles_across(pixels: int) -> int:
        """Return how many tiles are needed to cover *pixels* (never negative)."""
        return max(0, math.ceil(pixels / _TILE_SIZE))

    def setup(self, width: int, height: int, max_lights: int = 256) -> None:
        """Create compute pipeline, SSBOs, and descriptor set.

        Anything left over from a previous ``setup()`` is destroyed first. If
        creation fails part-way, the objects created so far are destroyed
        and the original exception is re-raised.

        Args:
            width: Viewport width in pixels.
            height: Viewport height in pixels.
            max_lights: Maximum number of lights in the scene.
        """
        # A repeated setup() must not orphan the previous Vulkan objects.
        self.destroy()

        self._max_lights = max_lights
        self._grid_x = self._tiles_across(width)
        self._grid_y = self._tiles_across(height)
        self._max_tiles = self._grid_x * self._grid_y
        # Worst case: every tile has every light
        lights_per_tile = max(0, min(max_lights, self._LIGHTS_PER_TILE_CAP))
        self._max_light_indices = self._max_tiles * lights_per_tile

        try:
            self._create_gpu_objects()
        except Exception:
            # Do not leak the objects that were created before the failure.
            self.destroy()
            raise

        self._ready = True
        log.debug("Light cull pass initialised (%dx%d tiles)", self._grid_x, self._grid_y)

    def _create_gpu_objects(self) -> None:
        """Allocate buffers, sampler, descriptors and the compute pipeline."""
        e = self._engine
        device = e._device
        phys = e._physical_device

        # CullParams SSBO (binding 0) — includes atomic counter
        self._cull_params_buf, self._cull_params_mem = create_buffer(
            device, phys, _CULL_PARAMS_SIZE,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
        )

        # Light index SSBO (binding 3) — output, one uint32 per entry.
        # The floor keeps the buffer non-empty for a 0x0 grid.
        index_buf_size = max(self._max_light_indices * 4, 4)
        self._light_index_buf, self._light_index_mem = create_buffer(
            device, phys, index_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )
        self._light_index_buf_size = index_buf_size

        # Tile SSBO (binding 4) — output: uvec2 (8 bytes) per tile
        tile_buf_size = max(self._max_tiles * 8, 8)
        self._tile_buf, self._tile_mem = create_buffer(
            device, phys, tile_buf_size,
            vk.VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
            vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT,
        )
        self._tile_buf_size = tile_buf_size

        self._depth_sampler = create_nearest_sampler(device)

        # Descriptor set layout: 5 bindings
        #   0: CullParams SSBO (compute, read+atomic)
        #   1: depth texture sampler (compute)
        #   2: LightBuffer SSBO (compute, read-only)
        #   3: LightIndexBuffer SSBO (compute, write)
        #   4: TileBuffer SSBO (compute, write)
        cs = vk.VK_SHADER_STAGE_COMPUTE_BIT
        ssbo = vk.VK_DESCRIPTOR_TYPE_STORAGE_BUFFER
        sampled = vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER
        self._desc_layout = create_descriptor_set_layout(device, [
            (0, ssbo, cs, 1),
            (1, sampled, cs, 1),
            (2, ssbo, cs, 1),
            (3, ssbo, cs, 1),
            (4, ssbo, cs, 1),
        ])
        self._desc_pool = create_pool_for_types(device, {
            ssbo: 4,
            sampled: 1,
        })
        self._desc_set = allocate_descriptor_set(device, self._desc_pool, self._desc_layout)

        # Write static buffer descriptors (bindings 0, 3, 4); depth + light SSBO
        # land later via update_descriptors() once the depth view exists.
        with DescriptorWriteBatch(device) as batch:
            batch.ssbo(self._desc_set, 0, self._cull_params_buf, _CULL_PARAMS_SIZE)
            batch.ssbo(self._desc_set, 3, self._light_index_buf, index_buf_size)
            batch.ssbo(self._desc_set, 4, self._tile_buf, tile_buf_size)

        # Compile compute shader and create pipeline (push constants: 2 mat4 = 128 bytes)
        self._pipeline, self._layout, self._module = create_compute_pipeline(
            device, e.shader_dir / "light_cull.comp", [self._desc_layout], 128,
        )

    def update_descriptors(self, depth_view: Any, light_ssbo: Any, light_ssbo_size: int) -> None:
        """Update depth texture (binding 1) and light SSBO (binding 2).

        Must be called after ``setup()``, and again after any ``resize()``
        that rebuilt the pass, because the descriptor set is recreated.

        Args:
            depth_view: Image view of the depth texture to sample.
            light_ssbo: Buffer handle of the scene light SSBO.
            light_ssbo_size: Size in bytes of ``light_ssbo`` to bind.
        """
        with DescriptorWriteBatch(self._engine._device) as batch:
            batch.image(self._desc_set, 1, depth_view, self._depth_sampler)
            batch.ssbo(self._desc_set, 2, light_ssbo, light_ssbo_size)

    def record(self, cmd: Any, frame: FrameContext) -> None:
        """RenderPass interface — delegates to _dispatch_impl with frame data.

        Does nothing until ``setup()`` has completed successfully.
        """
        if not self._ready:
            return
        self._dispatch_impl(
            cmd,
            frame.camera_view,
            frame.camera_proj,
            frame.light_count,
            frame.near,
            frame.far,
        )

    def _dispatch_impl(
        self,
        cmd: Any,
        view_mat: np.ndarray,
        proj_mat: np.ndarray,
        light_count: int,
        near: float,
        far: float,
    ) -> None:
        """Record the light culling compute dispatch into a command buffer.

        Must be called outside a render pass, after depth is available and
        before the main geometry pass that reads the tile data.

        Args:
            cmd: Command buffer being recorded.
            view_mat: View matrix, pushed transposed as float32.
            proj_mat: Projection matrix, pushed transposed as float32.
            light_count: Number of active lights; nothing is recorded when
                this is zero or negative.
            near: Near plane distance written to CullParams.
            far: Far plane distance written to CullParams.
        """
        # NOTE(review): when this returns early nothing is recorded, so the
        # tile and index buffers keep whatever an earlier dispatch wrote.
        # Confirm the fragment shader also checks the light count.
        if not self._ready or light_count <= 0:
            return
        device = self._engine._device

        # Upload cull parameters (includes resetting the atomic counter to 0).
        # NOTE(review): upload_numpy is called while recording, so it appears
        # to write host memory immediately rather than at execution time —
        # confirm this is safe with multiple frames in flight.
        params = np.zeros(_CULL_PARAMS_SIZE // 4, dtype=np.uint32)  # 48 bytes = 12 uint32
        params[0] = self._grid_x
        params[1] = self._grid_y
        params[2] = light_count
        # near/far are stored as raw float32 bit patterns in the uint32 array.
        as_float = params.view(np.float32)
        as_float[3] = near
        as_float[4] = far
        # params[5..7] = padding, params[8] = atomic counter = 0
        upload_numpy(device, self._cull_params_mem, params)

        # Memory barrier: ensure cull params upload is visible to compute
        host_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_HOST_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_HOST_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            0, 1, [host_barrier], 0, None, 0, None,
        )

        # Bind pipeline and descriptors
        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._layout,
            0, 1, [self._desc_set], 0, None,
        )

        # Push constants: view + proj (transposed for column-major GLSL)
        view_t = np.ascontiguousarray(view_mat.T, dtype=np.float32)
        proj_t = np.ascontiguousarray(proj_mat.T, dtype=np.float32)
        pc_bytes = view_t.tobytes() + proj_t.tobytes()
        # cbuf must stay referenced until the call below has returned.
        cbuf = vk.ffi.new("char[]", pc_bytes)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd, self._layout, vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0, len(pc_bytes), cbuf,
        )

        # Dispatch: one workgroup per tile
        vk.vkCmdDispatch(cmd, self._grid_x, self._grid_y, 1)

        # Barrier: compute writes → fragment reads
        compute_barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0, 1, [compute_barrier], 0, None, 0, None,
        )

    @property
    def ready(self) -> bool:
        """True once ``setup()`` has completed and until ``destroy()``."""
        return self._ready

    @property
    def tile_buffer(self) -> Any:
        """Tile SSBO handle for binding in fragment shader descriptors."""
        return self._tile_buf

    @property
    def tile_buffer_size(self) -> int:
        """Size in bytes of the tile SSBO as allocated (0 when none exists)."""
        return self._tile_buf_size

    @property
    def light_index_buffer(self) -> Any:
        """Light index SSBO handle for binding in fragment shader descriptors."""
        return self._light_index_buf

    @property
    def light_index_buffer_size(self) -> int:
        """Size in bytes of the light index SSBO as allocated (0 when none exists)."""
        return self._light_index_buf_size

    @property
    def grid_dims(self) -> tuple[int, int]:
        """Tile grid dimensions as ``(columns, rows)``."""
        return (self._grid_x, self._grid_y)

    def resize(self, width: int, height: int) -> None:
        """Handle viewport resize — recreate tile/index buffers if grid changed.

        Nothing happens when the new viewport maps to the same tile grid.
        Otherwise the pass is destroyed and set up again with the
        ``max_lights`` value from the most recent ``setup()``. The caller
        must then call ``update_descriptors()`` again and refresh any
        descriptor that references the output buffers.

        Args:
            width: New viewport width in pixels.
            height: New viewport height in pixels.
        """
        new_dims = (self._tiles_across(width), self._tiles_across(height))
        if new_dims == (self._grid_x, self._grid_y):
            return
        # Tear down and recreate with new dimensions
        self.destroy()
        self.setup(width, height, self._max_lights)

    def destroy(self) -> None:
        """Destroy all GPU resources.

        Safe to call repeatedly and after a failed ``setup()``: every handle
        that is still set is released and then cleared.
        """
        object_slots = (
            ("_pipeline", vk.vkDestroyPipeline),
            ("_layout", vk.vkDestroyPipelineLayout),
            ("_module", vk.vkDestroyShaderModule),
            ("_desc_layout", vk.vkDestroyDescriptorSetLayout),
            ("_desc_pool", vk.vkDestroyDescriptorPool),
            ("_depth_sampler", vk.vkDestroySampler),
        )
        buffer_slots = (
            ("_cull_params_buf", "_cull_params_mem"),
            ("_light_index_buf", "_light_index_mem"),
            ("_tile_buf", "_tile_mem"),
        )

        owns_objects = any(getattr(self, attr) for attr, _ in object_slots)
        owns_buffers = any(
            getattr(self, buf_attr) or getattr(self, mem_attr)
            for buf_attr, mem_attr in buffer_slots
        )
        if not (self._ready or owns_objects or owns_buffers):
            return

        device = self._engine._device
        for attr, destroy_fn in object_slots:
            handle = getattr(self, attr)
            if handle:
                destroy_fn(device, handle, None)
            setattr(self, attr, None)

        for buf_attr, mem_attr in buffer_slots:
            buf = getattr(self, buf_attr)
            mem = getattr(self, mem_attr)
            if buf:
                vk.vkDestroyBuffer(device, buf, None)
            if mem:
                vk.vkFreeMemory(device, mem, None)
            setattr(self, buf_attr, None)
            setattr(self, mem_attr, None)

        # The set was allocated from the pool destroyed above; just drop it.
        self._desc_set = None
        self._tile_buf_size = 0
        self._light_index_buf_size = 0
        self._ready = False
        log.debug("Light cull pass resources cleaned up")