Source code for simvx.graphics.renderer.colour_grading

"""Colour grading pass — LUT-based colour correction via compute shader."""

import logging
from typing import Any

import numpy as np
import vulkan as vk

from ..gpu.descriptors import (
    DescriptorWriteBatch,
    allocate_descriptor_set,
    create_descriptor_set_layout,
    create_pool_for_types,
)
from ..gpu.memory import create_buffer
from ..gpu.pipeline_compute import create_compute_pipeline

__all__ = ["ColourGradingPass"]

log = logging.getLogger(__name__)

# Push-constant block size in bytes: three vec4s (adjustments, temperature,
# resolution) at 16 bytes each = 48 bytes.
_PC_SIZE = 48

def generate_neutral_lut(size: int = 32) -> np.ndarray:
    """Generate an identity 3D LUT (no colour change).

    The array is indexed ``lut[b, g, r]`` and each texel holds
    ``[R, G, B, 255]``, where a channel value is ``int(i / (size - 1) * 255)``
    for grid index ``i``.

    Args:
        size: Number of grid points per axis. Must be at least 2, because the
            grid is normalised by ``size - 1``.

    Returns:
        ``uint8`` array of shape ``(size, size, size, 4)``.

    Raises:
        ValueError: If ``size`` is smaller than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")

    # float64 arithmetic followed by a truncating cast matches the scalar
    # int(i / (size - 1) * 255) computation exactly.
    ramp = (np.arange(size) / (size - 1) * 255).astype(np.uint8)

    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = ramp                              # red varies along the last grid axis
    lut[..., 1] = ramp[:, np.newaxis]               # green varies along the middle axis
    lut[..., 2] = ramp[:, np.newaxis, np.newaxis]   # blue varies along the first axis
    lut[..., 3] = 255
    return lut
def generate_warm_lut(size: int = 32) -> np.ndarray:
    """Generate a warm-toned 3D LUT (shifted toward orange/amber).

    Each channel is transformed independently: red is boosted and clamped to
    1.0, green is slightly reduced, blue is reduced. The array is indexed
    ``lut[b, g, r]`` and each texel holds ``[R, G, B, 255]``.

    Args:
        size: Number of grid points per axis. Must be at least 2, because the
            grid is normalised by ``size - 1``.

    Returns:
        ``uint8`` array of shape ``(size, size, size, 4)``.

    Raises:
        ValueError: If ``size`` is smaller than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")

    grid = np.arange(size) / (size - 1)

    # Warm shift: boost reds, slight green reduction, reduce blues.
    red = np.minimum(1.0, grid * 1.1 + 0.02)
    green = grid * 0.95 + 0.01
    blue = grid * 0.8

    # Every texel is written below, so no base LUT is needed.
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)
    lut[..., 1] = (green * 255).astype(np.uint8)[:, np.newaxis]
    lut[..., 2] = (blue * 255).astype(np.uint8)[:, np.newaxis, np.newaxis]
    lut[..., 3] = 255
    return lut
def generate_cool_lut(size: int = 32) -> np.ndarray:
    """Generate a cool-toned 3D LUT (shifted toward blue/teal).

    Each channel is transformed independently: red is reduced, green and blue
    are boosted and clamped to 1.0. The array is indexed ``lut[b, g, r]`` and
    each texel holds ``[R, G, B, 255]``.

    Args:
        size: Number of grid points per axis. Must be at least 2, because the
            grid is normalised by ``size - 1``.

    Returns:
        ``uint8`` array of shape ``(size, size, size, 4)``.

    Raises:
        ValueError: If ``size`` is smaller than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")

    grid = np.arange(size) / (size - 1)

    # Cool shift: reduce reds, boost greens slightly, boost blues.
    red = grid * 0.85
    green = np.minimum(1.0, grid * 1.02 + 0.01)
    blue = np.minimum(1.0, grid * 1.15 + 0.02)

    # Every texel is written below, so no base LUT is needed.
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (red * 255).astype(np.uint8)
    lut[..., 1] = (green * 255).astype(np.uint8)[:, np.newaxis]
    lut[..., 2] = (blue * 255).astype(np.uint8)[:, np.newaxis, np.newaxis]
    lut[..., 3] = 255
    return lut
def generate_vintage_lut(size: int = 32) -> np.ndarray:
    """Generate a vintage/desaturated warm 3D LUT.

    Per texel the transform is, in order: desaturate toward luminance
    (saturation 0.6), apply a warm tint, then lift the blacks for a faded
    look. The array is indexed ``lut[b, g, r]`` and each texel holds
    ``[R, G, B, 255]``.

    Args:
        size: Number of grid points per axis. Must be at least 2, because the
            grid is normalised by ``size - 1``.

    Returns:
        ``uint8`` array of shape ``(size, size, size, 4)``.

    Raises:
        ValueError: If ``size`` is smaller than 2.
    """
    if size < 2:
        raise ValueError(f"LUT size must be at least 2, got {size}")

    grid = np.arange(size) / (size - 1)
    # Broadcastable views so that index order is [b, g, r].
    rf = grid[np.newaxis, np.newaxis, :]
    gf = grid[np.newaxis, :, np.newaxis]
    bf = grid[:, np.newaxis, np.newaxis]

    # Desaturate toward luminance.
    luma = 0.2126 * rf + 0.7152 * gf + 0.0722 * bf
    sat = 0.6  # reduced saturation
    rf = luma + (rf - luma) * sat
    gf = luma + (gf - luma) * sat
    bf = luma + (bf - luma) * sat

    # Warm tint.
    rf = np.minimum(1.0, rf * 1.05 + 0.03)
    gf = gf * 0.95
    bf = bf * 0.75

    # Lifted blacks (fade effect).
    rf = rf * 0.9 + 0.05
    gf = gf * 0.9 + 0.04
    bf = bf * 0.9 + 0.06

    # Every texel is written below, so no base LUT is needed.
    lut = np.empty((size, size, size, 4), dtype=np.uint8)
    lut[..., 0] = (np.minimum(1.0, rf) * 255).astype(np.uint8)
    lut[..., 1] = (np.minimum(1.0, gf) * 255).astype(np.uint8)
    lut[..., 2] = (np.minimum(1.0, bf) * 255).astype(np.uint8)
    lut[..., 3] = 255
    return lut
def _kelvin_to_rgb_multipliers(kelvin: float) -> tuple[float, float]: """Approximate colour temperature as R and B multipliers (G stays at 1.0). Returns (r_mult, b_mult) relative to 6500K neutral. Uses simplified Planckian locus approximation. """ # Normalize around 6500K t = (kelvin - 6500.0) / 6500.0 if t > 0: # Warmer: boost red, reduce blue r = 1.0 + t * 0.15 b = 1.0 - t * 0.25 else: # Cooler: reduce red, boost blue r = 1.0 + t * 0.25 b = 1.0 - t * 0.15 return (max(0.5, min(1.5, r)), max(0.5, min(1.5, b)))
class ColourGradingPass:
    """Compute-based colour grading: LUT lookup + brightness/contrast/saturation/temperature.

    Operates in-place on the HDR colour image. Apply after fog, before tone mapping.

    Lifecycle: ``setup()`` creates the GPU objects, ``render()`` records the
    dispatch, ``resize()`` rebinds the colour image, ``set_lut()`` swaps the
    LUT texture, and ``cleanup()`` destroys whatever was created.
    """

    def __init__(self, engine: Any):
        self._engine = engine
        self._ready = False

        # Compute pipeline
        self._pipeline: Any = None
        self._layout: Any = None
        self._module: Any = None

        # Descriptors
        self._desc_pool: Any = None
        self._desc_layout: Any = None
        self._desc_set: Any = None

        # LUT 3D texture
        self._lut_image: Any = None
        self._lut_memory: Any = None
        self._lut_view: Any = None
        self._lut_sampler: Any = None
        self._lut_size: int = 0

        # Dimensions
        self._width: int = 0
        self._height: int = 0

        # Public settings
        self.enabled: bool = False
        self.brightness: float = 0.0  # -1 to 1
        self.contrast: float = 1.0  # 0 to 2
        self.saturation: float = 1.0  # 0 to 2
        self.colour_temperature: float = 6500.0  # Kelvin, 1000-12000

    def setup(self, width: int, height: int, colour_view: Any) -> None:
        """Initialize colour grading pipeline and upload default neutral LUT.

        Args:
            width: Width of the colour image in pixels.
            height: Height of the colour image in pixels.
            colour_view: Image view bound as the storage image (binding 0).
        """
        self._width = width
        self._height = height

        self._create_lut_sampler()
        self._upload_lut(generate_neutral_lut(32))
        self._create_descriptors(colour_view)
        self._create_pipeline()

        self._ready = True
        log.debug("Colour grading pass initialized (%dx%d)", width, height)

    def set_lut(self, lut_data: np.ndarray) -> None:
        """Upload a new 3D LUT texture, replacing the current one.

        Does nothing if the pass has not been set up.

        Args:
            lut_data: ``uint8`` array of shape ``(size, size, size, 4)``.

        Raises:
            ValueError: If ``lut_data`` has the wrong shape or dtype. The
                current LUT is left untouched in that case.
        """
        if not self._ready:
            return

        # Validate before touching the old LUT so a bad input is harmless.
        lut = self._validated_lut(lut_data)

        # The old image/view may still be referenced by submitted command
        # buffers; Vulkan requires that work to finish before destruction.
        vk.vkDeviceWaitIdle(self._engine.ctx.device)

        # Stay "not ready" while binding 1 points at a destroyed view, so a
        # failed upload cannot lead to a dispatch with a dangling descriptor.
        self._ready = False
        self._release_lut()
        self._upload_lut(lut)
        self._update_lut_descriptor()
        self._ready = True

    @staticmethod
    def _validated_lut(lut_data: Any) -> np.ndarray:
        """Return ``lut_data`` as a contiguous array after checking shape and dtype."""
        lut = np.asarray(lut_data)
        cube = lut.ndim == 4 and lut.shape[0] == lut.shape[1] == lut.shape[2]
        if not cube or lut.shape[3] != 4 or lut.shape[0] < 1:
            raise ValueError(f"LUT must have shape (size, size, size, 4), got {lut.shape}")
        if lut.dtype != np.uint8:
            raise ValueError(f"LUT must have dtype uint8, got {lut.dtype}")
        return np.ascontiguousarray(lut)

    def _destroy_handle(self, attr: str, destroy: Any) -> None:
        """Destroy the handle stored in ``attr`` (if any) and reset it to None."""
        handle = getattr(self, attr)
        if handle:
            destroy(self._engine.ctx.device, handle, None)
            setattr(self, attr, None)

    def _release_lut(self) -> None:
        """Destroy the LUT image view, image and memory (the sampler is kept)."""
        self._destroy_handle("_lut_view", vk.vkDestroyImageView)
        self._destroy_handle("_lut_image", vk.vkDestroyImage)
        self._destroy_handle("_lut_memory", vk.vkFreeMemory)

    def _create_lut_sampler(self) -> None:
        """Create trilinear sampler for 3D LUT texture."""
        self._lut_sampler = vk.vkCreateSampler(
            self._engine.ctx.device,
            vk.VkSamplerCreateInfo(
                magFilter=vk.VK_FILTER_LINEAR,
                minFilter=vk.VK_FILTER_LINEAR,
                addressModeU=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeV=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                addressModeW=vk.VK_SAMPLER_ADDRESS_MODE_CLAMP_TO_EDGE,
                anisotropyEnable=vk.VK_FALSE,
                unnormalizedCoordinates=vk.VK_FALSE,
                mipmapMode=vk.VK_SAMPLER_MIPMAP_MODE_LINEAR,
            ),
            None,
        )

    @staticmethod
    def _colour_subresource_range() -> Any:
        """Subresource range covering the single colour mip level and layer."""
        return vk.VkImageSubresourceRange(
            aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
            baseMipLevel=0,
            levelCount=1,
            baseArrayLayer=0,
            layerCount=1,
        )

    def _lut_layout_barrier(self, src_access: int, dst_access: int, old_layout: int, new_layout: int) -> Any:
        """Build an image memory barrier that transitions the LUT image layout."""
        return vk.VkImageMemoryBarrier(
            srcAccessMask=src_access,
            dstAccessMask=dst_access,
            oldLayout=old_layout,
            newLayout=new_layout,
            srcQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
            dstQueueFamilyIndex=vk.VK_QUEUE_FAMILY_IGNORED,
            image=self._lut_image,
            subresourceRange=self._colour_subresource_range(),
        )

    def _upload_lut(self, lut_data: np.ndarray) -> None:
        """Upload a 3D LUT as a VK_IMAGE_TYPE_3D texture.

        Creates the image, its device-local memory and a 3D image view, and
        copies the texels through a temporary staging buffer. The staging
        buffer is released even if a step fails.
        """
        from ..gpu.memory import begin_single_time_commands, end_single_time_commands

        e = self._engine
        device = e.ctx.device
        size = lut_data.shape[0]
        self._lut_size = size

        pixel_data = np.ascontiguousarray(lut_data)

        # Staging buffer
        staging_buf, staging_mem = _create_staging_buffer(device, e.ctx.physical_device, pixel_data)
        try:
            # Create 3D image
            img_ci = vk.VkImageCreateInfo(
                imageType=vk.VK_IMAGE_TYPE_3D,
                format=vk.VK_FORMAT_R8G8B8A8_UNORM,
                extent=vk.VkExtent3D(width=size, height=size, depth=size),
                mipLevels=1,
                arrayLayers=1,
                samples=vk.VK_SAMPLE_COUNT_1_BIT,
                tiling=vk.VK_IMAGE_TILING_OPTIMAL,
                usage=vk.VK_IMAGE_USAGE_TRANSFER_DST_BIT | vk.VK_IMAGE_USAGE_SAMPLED_BIT,
                sharingMode=vk.VK_SHARING_MODE_EXCLUSIVE,
                initialLayout=vk.VK_IMAGE_LAYOUT_UNDEFINED,
            )
            self._lut_image = vk.vkCreateImage(device, img_ci, None)

            mem_reqs = vk.vkGetImageMemoryRequirements(device, self._lut_image)
            mem_props = vk.vkGetPhysicalDeviceMemoryProperties(e.ctx.physical_device)
            mem_type = _find_memory_type(
                mem_props, mem_reqs.memoryTypeBits, vk.VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT
            )
            self._lut_memory = vk.vkAllocateMemory(
                device,
                vk.VkMemoryAllocateInfo(
                    allocationSize=mem_reqs.size,
                    memoryTypeIndex=mem_type,
                ),
                None,
            )
            vk.vkBindImageMemory(device, self._lut_image, self._lut_memory, 0)

            cmd = begin_single_time_commands(device, e.ctx.command_pool)

            # Transition UNDEFINED -> TRANSFER_DST
            to_transfer = self._lut_layout_barrier(
                0,
                vk.VK_ACCESS_TRANSFER_WRITE_BIT,
                vk.VK_IMAGE_LAYOUT_UNDEFINED,
                vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
            )
            vk.vkCmdPipelineBarrier(
                cmd,
                vk.VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT,
                vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
                0,
                0,
                None,
                0,
                None,
                1,
                [to_transfer],
            )

            # Copy staging -> image
            region = vk.VkBufferImageCopy(
                bufferOffset=0,
                bufferRowLength=0,
                bufferImageHeight=0,
                imageSubresource=vk.VkImageSubresourceLayers(
                    aspectMask=vk.VK_IMAGE_ASPECT_COLOR_BIT,
                    mipLevel=0,
                    baseArrayLayer=0,
                    layerCount=1,
                ),
                imageOffset=vk.VkOffset3D(x=0, y=0, z=0),
                imageExtent=vk.VkExtent3D(width=size, height=size, depth=size),
            )
            vk.vkCmdCopyBufferToImage(
                cmd,
                staging_buf,
                self._lut_image,
                vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                1,
                [region],
            )

            # Transition TRANSFER_DST -> SHADER_READ_ONLY
            to_shader = self._lut_layout_barrier(
                vk.VK_ACCESS_TRANSFER_WRITE_BIT,
                vk.VK_ACCESS_SHADER_READ_BIT,
                vk.VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL,
                vk.VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL,
            )
            vk.vkCmdPipelineBarrier(
                cmd,
                vk.VK_PIPELINE_STAGE_TRANSFER_BIT,
                vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
                0,
                0,
                None,
                0,
                None,
                1,
                [to_shader],
            )

            end_single_time_commands(device, e.ctx.graphics_queue, e.ctx.command_pool, cmd)
        finally:
            # Cleanup staging (also on failure, so it cannot leak)
            vk.vkDestroyBuffer(device, staging_buf, None)
            vk.vkFreeMemory(device, staging_mem, None)

        # Image view (3D)
        self._lut_view = vk.vkCreateImageView(
            device,
            vk.VkImageViewCreateInfo(
                image=self._lut_image,
                viewType=vk.VK_IMAGE_VIEW_TYPE_3D,
                format=vk.VK_FORMAT_R8G8B8A8_UNORM,
                subresourceRange=self._colour_subresource_range(),
            ),
            None,
        )

    def _create_descriptors(self, colour_view: Any) -> None:
        """Create descriptor set: colour storage image (binding 0) + LUT sampler (binding 1)."""
        device = self._engine.ctx.device
        cs = vk.VK_SHADER_STAGE_COMPUTE_BIT

        self._desc_layout = create_descriptor_set_layout(device, [
            (0, vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, cs, 1),
            (1, vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER, cs, 1),
        ])
        self._desc_pool = create_pool_for_types(device, {
            vk.VK_DESCRIPTOR_TYPE_STORAGE_IMAGE: 1,
            vk.VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER: 1,
        })
        self._desc_set = allocate_descriptor_set(device, self._desc_pool, self._desc_layout)
        self._write_descriptors(colour_view)

    def _write_descriptors(self, colour_view: Any) -> None:
        """Write colour storage image and LUT sampler to descriptor set."""
        with DescriptorWriteBatch(self._engine.ctx.device) as batch:
            batch.storage_image(self._desc_set, 0, colour_view)
            batch.image(self._desc_set, 1, self._lut_view, self._lut_sampler)

    def _update_lut_descriptor(self) -> None:
        """Update LUT descriptor (binding 1) after LUT replacement."""
        with DescriptorWriteBatch(self._engine.ctx.device) as batch:
            batch.image(self._desc_set, 1, self._lut_view, self._lut_sampler)

    def _create_pipeline(self) -> None:
        """Create colour grading compute pipeline."""
        e = self._engine
        self._pipeline, self._layout, self._module = create_compute_pipeline(
            e.ctx.device,
            e.shader_dir / "colour_grade.comp",
            [self._desc_layout],
            _PC_SIZE,
        )

    def render(self, cmd: Any) -> None:
        """Dispatch colour grading compute shader. Call after fog, before tonemap.

        Does nothing when the pass is not set up, is disabled, or the target
        has a zero dimension.

        Args:
            cmd: Active command buffer (outside any render pass).
        """
        if not self._ready or not self.enabled:
            return
        # A zero-sized target (e.g. after resize(0, 0)) has nothing to grade
        # and would otherwise divide by zero when packing the resolution.
        if self._width <= 0 or self._height <= 0:
            return

        ffi = vk.ffi
        # One workgroup per 8x8 tile, rounded up (assumes the shader's local
        # size is 8x8 - confirm against colour_grade.comp).
        groups_x = (self._width + 7) // 8
        groups_y = (self._height + 7) // 8

        vk.vkCmdBindPipeline(cmd, vk.VK_PIPELINE_BIND_POINT_COMPUTE, self._pipeline)
        vk.vkCmdBindDescriptorSets(
            cmd,
            vk.VK_PIPELINE_BIND_POINT_COMPUTE,
            self._layout,
            0,
            1,
            [self._desc_set],
            0,
            None,
        )

        # Compute temperature multipliers
        r_mult, b_mult = _kelvin_to_rgb_multipliers(self.colour_temperature)

        # Pack push constants: 3 * vec4 = 48 bytes
        pc_data = np.array(
            [
                # adjustments
                self.brightness,
                self.contrast,
                self.saturation,
                float(self._lut_size),
                # temperature
                r_mult,
                b_mult,
                0.0,
                0.0,
                # resolution
                float(self._width),
                float(self._height),
                1.0 / self._width,
                1.0 / self._height,
            ],
            dtype=np.float32,
        ).tobytes()
        cbuf = ffi.new("char[]", pc_data)
        vk._vulkan.lib.vkCmdPushConstants(
            cmd,
            self._layout,
            vk.VK_SHADER_STAGE_COMPUTE_BIT,
            0,
            _PC_SIZE,
            cbuf,
        )

        vk.vkCmdDispatch(cmd, groups_x, groups_y, 1)

        # Barrier: colour grading write -> next pass read
        barrier = vk.VkMemoryBarrier(
            srcAccessMask=vk.VK_ACCESS_SHADER_WRITE_BIT,
            dstAccessMask=vk.VK_ACCESS_SHADER_READ_BIT | vk.VK_ACCESS_SHADER_WRITE_BIT,
        )
        vk.vkCmdPipelineBarrier(
            cmd,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT,
            vk.VK_PIPELINE_STAGE_COMPUTE_SHADER_BIT | vk.VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT,
            0,
            1,
            [barrier],
            0,
            None,
            0,
            None,
        )

    def resize(self, width: int, height: int, colour_view: Any) -> None:
        """Update descriptors for new dimensions.

        Does nothing if the pass has not been set up.
        """
        if not self._ready:
            return
        self._width = width
        self._height = height
        self._write_descriptors(colour_view)

    def cleanup(self) -> None:
        """Release all GPU resources.

        Driven by the stored handles rather than the ready flag, so objects
        created by a ``setup()`` that failed part-way are released too. Each
        handle is reset to None, which makes repeated calls harmless.
        """
        self._ready = False

        self._destroy_handle("_pipeline", vk.vkDestroyPipeline)
        self._destroy_handle("_layout", vk.vkDestroyPipelineLayout)
        self._destroy_handle("_module", vk.vkDestroyShaderModule)
        self._destroy_handle("_desc_pool", vk.vkDestroyDescriptorPool)
        # The set was allocated from the pool destroyed above; drop the stale handle.
        self._desc_set = None
        self._destroy_handle("_desc_layout", vk.vkDestroyDescriptorSetLayout)
        self._release_lut()
        self._destroy_handle("_lut_sampler", vk.vkDestroySampler)
def _find_memory_type(mem_props: Any, type_filter: int, properties: int) -> int: """Find a suitable memory type index.""" for i in range(mem_props.memoryTypeCount): if (type_filter & (1 << i)) and (mem_props.memoryTypes[i].propertyFlags & properties) == properties: return i raise RuntimeError("Failed to find suitable memory type") def _create_staging_buffer(device: Any, physical_device: Any, data: np.ndarray) -> tuple[Any, Any]: """Create and fill a host-visible staging buffer from numpy data.""" ffi = vk.ffi size = data.nbytes buf, mem = create_buffer( device, physical_device, size, vk.VK_BUFFER_USAGE_TRANSFER_SRC_BIT, vk.VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | vk.VK_MEMORY_PROPERTY_HOST_COHERENT_BIT, ) dst = vk.vkMapMemory(device, mem, 0, size, 0) ffi.memmove(dst, ffi.cast("void*", data.ctypes.data), size) vk.vkUnmapMemory(device, mem) return buf, mem