"""Vulkan timestamp query pool for per-pass GPU profiling.
A :class:`TimestampPool` wraps a single ``VkQueryPool`` of
``VK_QUERY_TYPE_TIMESTAMP`` and exposes a label-based API::
pool.reset(cmd)
pool.begin(cmd, "shadow")
# ... draw commands ...
pool.end(cmd, "shadow")
# submit, wait, then:
timings = pool.read_results() # {"shadow": gpu_ms, ...}
The pool itself is purely additive — constructing one allocates a
``VkQueryPool`` and nothing else; no calls are issued unless the higher
layer chooses to instrument a frame.
"""
from __future__ import annotations
import logging
from typing import Any
import vulkan as vk
# Names exported by ``from <module> import *``; only the pool class is public.
__all__ = ["TimestampPool"]
# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)
class TimestampPool:
    """Per-pass GPU timestamp pool.

    Each label consumes two query slots: one written by :meth:`begin` and one
    written by :meth:`end`. Slots are handed out in recording order, so every
    slot below ``_next_index`` has had a timestamp write recorded for it.

    Strict misuse — re-beginning a label, ending an unknown or already-ended
    label, exceeding ``max_labels`` distinct labels per frame, or recording
    after :meth:`destroy` — raises ``RuntimeError``.

    The pool can be used as a context manager; leaving the ``with`` block
    calls :meth:`destroy`.
    """

    __slots__ = (
        "device",
        "physical_device",
        "max_labels",
        "ns_per_tick",
        "pool",
        "_labels",
        "_next_index",
    )

    def __init__(self, device: Any, physical_device: Any, max_labels: int = 64) -> None:
        """Create the underlying ``VkQueryPool``.

        Args:
            device: Logical device handle passed to ``vkCreateQueryPool``.
            physical_device: Physical device whose ``limits.timestampPeriod``
                is read to convert ticks to nanoseconds.
            max_labels: Maximum number of distinct labels per frame; the pool
                is created with ``max_labels * 2`` queries.

        Raises:
            ValueError: If ``max_labels`` is not positive.
            RuntimeError: If the device reports a non-positive
                ``timestampPeriod``.
        """
        if max_labels <= 0:
            raise ValueError(f"max_labels must be positive, got {max_labels}")
        self.device = device
        self.physical_device = physical_device
        self.max_labels = max_labels

        props = vk.vkGetPhysicalDeviceProperties(physical_device)
        period = float(props.limits.timestampPeriod)
        if period <= 0.0:
            raise RuntimeError(
                "Physical device does not support timestamp queries (timestampPeriod == 0)"
            )
        self.ns_per_tick = period

        create_info = vk.VkQueryPoolCreateInfo(
            queryType=vk.VK_QUERY_TYPE_TIMESTAMP,
            queryCount=max_labels * 2,
        )
        self.pool: Any = vk.vkCreateQueryPool(device, create_info, None)
        # label -> [begin_index, end_index | None]; end_index is set in end().
        self._labels: dict[str, list[int | None]] = {}
        # Next unused query slot; also the number of slots written this frame.
        self._next_index: int = 0
        log.debug(
            "TimestampPool created (max_labels=%d, ns/tick=%.3f)",
            max_labels,
            self.ns_per_tick,
        )

    def _require_pool(self, op: str) -> Any:
        """Return the live query pool or raise if :meth:`destroy` already ran."""
        pool = self.pool
        if pool is None:
            raise RuntimeError(f"TimestampPool.{op}: pool has been destroyed")
        return pool

    # ------------------------------------------------------------------
    # Per-frame command recording
    # ------------------------------------------------------------------

    def reset(self, cmd: Any) -> None:
        """Reset the pool at the start of a frame.

        Must be recorded into ``cmd`` before any ``begin``/``end`` calls.
        Clears the host-side label tracking so labels can be re-issued.

        Raises:
            RuntimeError: If the pool has been destroyed.
        """
        pool = self._require_pool("reset")
        vk.vkCmdResetQueryPool(cmd, pool, 0, self.max_labels * 2)
        self._labels.clear()
        self._next_index = 0

    def begin(self, cmd: Any, label: str) -> None:
        """Record a TOP_OF_PIPE timestamp for ``label``.

        Raises:
            RuntimeError: If the pool has been destroyed, if ``label`` was
                already begun this frame, or if ``max_labels`` distinct
                labels are already in use.
        """
        pool = self._require_pool("begin")
        if label in self._labels:
            raise RuntimeError(
                f"TimestampPool.begin: label {label!r} already begun this frame without end()"
            )
        # Limit by label count, not by slot count: every begun label must keep
        # one slot in reserve for its end(), otherwise a later end() on a
        # legitimately begun label would run out of slots.
        if len(self._labels) >= self.max_labels:
            raise RuntimeError(
                f"TimestampPool capacity exceeded: max_labels={self.max_labels}, "
                f"cannot begin {label!r}"
            )
        idx = self._next_index
        vk.vkCmdWriteTimestamp(cmd, vk.VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, pool, idx)
        # Commit host-side state only after the command was recorded, so a
        # failed call does not leave a slot counted that was never written.
        self._labels[label] = [idx, None]
        self._next_index = idx + 1

    def end(self, cmd: Any, label: str) -> None:
        """Record a BOTTOM_OF_PIPE timestamp matching a prior ``begin(label)``.

        Raises:
            RuntimeError: If the pool has been destroyed, if ``label`` was
                never begun this frame, or if it was already ended.
        """
        pool = self._require_pool("end")
        entry = self._labels.get(label)
        if entry is None:
            raise RuntimeError(f"TimestampPool.end: label {label!r} was never begun this frame")
        if entry[1] is not None:
            raise RuntimeError(f"TimestampPool.end: label {label!r} already ended this frame")
        # begin() caps the label count, so a slot is always available here;
        # the check is kept as a guard against externally mutated state.
        if self._next_index >= self.max_labels * 2:
            raise RuntimeError(
                f"TimestampPool capacity exceeded: max_labels={self.max_labels}, "
                f"cannot end {label!r}"
            )
        idx = self._next_index
        vk.vkCmdWriteTimestamp(cmd, vk.VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, pool, idx)
        entry[1] = idx
        self._next_index = idx + 1

    # ------------------------------------------------------------------
    # Result readback
    # ------------------------------------------------------------------

    def read_results(self) -> dict[str, float]:
        """Read all begin/end pairs written this frame and return ms per label.

        Uses ``VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT`` so the
        call blocks until the GPU has written every queried slot. Labels
        that have a begin but no matching end are skipped.

        Returns:
            Mapping of label to elapsed GPU time in milliseconds; empty when
            no timestamps were recorded since the last reset.
        """
        count = self._next_index
        if count == 0:
            return {}

        stride = vk.ffi.sizeof("uint64_t")
        data = vk.ffi.new(f"uint64_t[{count}]")
        vk.vkGetQueryPoolResults(
            self.device,
            self.pool,
            0,
            count,
            stride * count,
            data,
            stride,
            vk.VK_QUERY_RESULT_64_BIT | vk.VK_QUERY_RESULT_WAIT_BIT,
        )

        ns_per_tick = self.ns_per_tick
        # NOTE(review): the subtraction assumes the counter did not wrap
        # between begin and end; queues with fewer than 64 valid timestamp
        # bits could wrap — confirm against the queue family in use.
        return {
            label: ((data[end_idx] - data[begin_idx]) * ns_per_tick) / 1_000_000.0
            for label, (begin_idx, end_idx) in self._labels.items()
            if end_idx is not None
        }

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def destroy(self) -> None:
        """Destroy the underlying ``VkQueryPool``.

        Safe to call more than once; later calls only clear host-side state.
        """
        if self.pool is not None:
            vk.vkDestroyQueryPool(self.device, self.pool, None)
            self.pool = None
        self._labels.clear()
        self._next_index = 0

    def __enter__(self) -> "TimestampPool":
        """Return the pool itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Destroy the pool on leaving the ``with`` block; exceptions propagate."""
        self.destroy()