# Source code for simvx.graphics.gpu.timestamp_pool

"""Vulkan timestamp query pool for per-pass GPU profiling.

A :class:`TimestampPool` wraps a single ``VkQueryPool`` of
``VK_QUERY_TYPE_TIMESTAMP`` and exposes a label-based API::

    pool.reset(cmd)
    pool.begin(cmd, "shadow")
    # ... draw commands ...
    pool.end(cmd, "shadow")
    # submit, wait, then:
    timings = pool.read_results()  # {"shadow": gpu_ms, ...}

The pool itself is purely additive — constructing one queries the physical
device properties and allocates a ``VkQueryPool``; no commands are recorded
into any command buffer unless the higher layer chooses to instrument a frame.
"""

from __future__ import annotations

import logging
from typing import Any

import vulkan as vk

# Names exported by ``from ... import *``.
__all__ = ["TimestampPool"]

# Module logger named after this module's dotted path.
log = logging.getLogger(name=__name__)


class TimestampPool:
    """Per-pass GPU timestamp pool.

    Each label consumes two query slots: one written by :meth:`begin` and one
    by :meth:`end`. Slots are handed out in recording order, so with nested
    labels a label's two slots are not necessarily adjacent.

    Strict misuse raises ``RuntimeError``. This covers re-beginning a label
    within a frame, ending an unknown or already-ended label, beginning more
    than ``max_labels`` distinct labels per frame, and recording after
    :meth:`destroy`.
    """

    __slots__ = (
        "device",
        "physical_device",
        "max_labels",
        "ns_per_tick",
        "pool",
        "_labels",
        "_next_index",
    )

    def __init__(self, device: Any, physical_device: Any, max_labels: int = 64) -> None:
        """Create a timestamp query pool with room for ``max_labels`` labels.

        Args:
            device: The ``VkDevice`` that owns the query pool.
            physical_device: The ``VkPhysicalDevice`` whose
                ``limits.timestampPeriod`` converts ticks to nanoseconds.
            max_labels: Maximum number of distinct labels per frame. The
                underlying pool holds ``2 * max_labels`` queries.

        Raises:
            ValueError: If ``max_labels`` is not positive.
            RuntimeError: If the device reports a non-positive timestamp period.
        """
        if max_labels <= 0:
            raise ValueError(f"max_labels must be positive, got {max_labels}")
        self.device = device
        self.physical_device = physical_device
        self.max_labels = max_labels

        props = vk.vkGetPhysicalDeviceProperties(physical_device)
        period = float(props.limits.timestampPeriod)
        if period <= 0.0:
            raise RuntimeError("Physical device does not support timestamp queries (timestampPeriod == 0)")
        # Nanoseconds per timestamp tick, used to convert raw query values.
        self.ns_per_tick = period

        create_info = vk.VkQueryPoolCreateInfo(
            queryType=vk.VK_QUERY_TYPE_TIMESTAMP,
            queryCount=max_labels * 2,
        )
        self.pool: Any = vk.vkCreateQueryPool(device, create_info, None)

        # label -> [begin_index, end_index | None]; end_index is filled in by end().
        self._labels: dict[str, list[int | None]] = {}
        # Next free query slot; every slot below it has been written this frame.
        self._next_index: int = 0
        log.debug("TimestampPool created (max_labels=%d, ns/tick=%.3f)", max_labels, self.ns_per_tick)

    def _ensure_alive(self, method: str) -> None:
        """Raise ``RuntimeError`` if :meth:`destroy` has already released the pool."""
        if self.pool is None:
            raise RuntimeError(f"TimestampPool.{method}: pool has been destroyed")

    # ------------------------------------------------------------------
    # Per-frame command recording
    # ------------------------------------------------------------------

    def reset(self, cmd: Any) -> None:
        """Reset every query slot and clear the host-side label tracking.

        Must be recorded into ``cmd`` before any ``begin``/``end`` calls for the
        frame. The host-side tracking is cleared immediately, so call
        :meth:`read_results` for the previous frame *before* calling this.

        Raises:
            RuntimeError: If the pool has been destroyed.
        """
        self._ensure_alive("reset")
        vk.vkCmdResetQueryPool(cmd, self.pool, 0, self.max_labels * 2)
        self._labels.clear()
        self._next_index = 0

    def begin(self, cmd: Any, label: str) -> None:
        """Record a TOP_OF_PIPE timestamp opening ``label``.

        Raises:
            RuntimeError: If the pool has been destroyed, or ``label`` was
                already begun this frame (ended or not). Also raised if
                ``max_labels`` distinct labels have already been begun.
        """
        self._ensure_alive("begin")
        entry = self._labels.get(label)
        if entry is not None:
            if entry[1] is None:
                raise RuntimeError(f"TimestampPool.begin: label {label!r} already begun this frame without end()")
            raise RuntimeError(f"TimestampPool.begin: label {label!r} already begun and ended this frame")
        # Cap distinct labels here rather than raw slots. That way every begun
        # label is guaranteed a free slot for its matching end().
        if len(self._labels) >= self.max_labels:
            raise RuntimeError(
                f"TimestampPool capacity exceeded: max_labels={self.max_labels}, "
                f"cannot begin {label!r}"
            )
        idx = self._next_index
        self._labels[label] = [idx, None]
        self._next_index += 1
        vk.vkCmdWriteTimestamp(cmd, vk.VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, self.pool, idx)

    def end(self, cmd: Any, label: str) -> None:
        """Record a BOTTOM_OF_PIPE timestamp closing a prior ``begin(label)``.

        Raises:
            RuntimeError: If the pool has been destroyed, ``label`` was never
                begun this frame, or ``label`` was already ended.
        """
        self._ensure_alive("end")
        entry = self._labels.get(label)
        if entry is None:
            raise RuntimeError(f"TimestampPool.end: label {label!r} was never begun this frame")
        if entry[1] is not None:
            raise RuntimeError(f"TimestampPool.end: label {label!r} already ended this frame")
        # No capacity check needed: begin() admits at most max_labels labels,
        # and this label has not ended yet. So
        # _next_index < 2 * len(_labels) <= 2 * max_labels.
        idx = self._next_index
        entry[1] = idx
        self._next_index += 1
        vk.vkCmdWriteTimestamp(cmd, vk.VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, self.pool, idx)

    # ------------------------------------------------------------------
    # Result readback
    # ------------------------------------------------------------------

    def read_results(self) -> dict[str, float]:
        """Read all begin/end pairs written this frame and return ms per label.

        Uses ``VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT``, so the call
        blocks until the GPU has written every queried slot. Call it only after
        submitting the command buffer that recorded the timestamps, and before
        the next :meth:`reset`. Labels that have a begin but no matching end are
        skipped.

        Returns:
            Mapping of label to elapsed GPU time in milliseconds. The mapping
            is empty if nothing was recorded.
        """
        if self._next_index == 0:
            return {}

        count = self._next_index
        data = vk.ffi.new(f"uint64_t[{count}]")
        stride = vk.ffi.sizeof("uint64_t")
        size = stride * count
        vk.vkGetQueryPoolResults(
            self.device,
            self.pool,
            0,
            count,
            size,
            data,
            stride,
            vk.VK_QUERY_RESULT_64_BIT | vk.VK_QUERY_RESULT_WAIT_BIT,
        )

        ns_per_tick = self.ns_per_tick
        results: dict[str, float] = {}
        for label, (begin_idx, end_idx) in self._labels.items():
            if begin_idx is None or end_idx is None:
                continue
            ticks = data[end_idx] - data[begin_idx]
            results[label] = (ticks * ns_per_tick) / 1_000_000.0
        return results

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def destroy(self) -> None:
        """Destroy the underlying ``VkQueryPool``. Safe to call more than once."""
        if self.pool is not None:
            vk.vkDestroyQueryPool(self.device, self.pool, None)
            self.pool = None
        self._labels.clear()
        self._next_index = 0

    def __enter__(self) -> TimestampPool:
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        self.destroy()