"""Backend-agnostic shadow math.
Computes directional-light cascaded-shadow-map (CSM) split distances and
per-cascade light view-projection matrices. Pure NumPy — no Vulkan, no
WebGPU — so the desktop renderer and the web streaming serializer share the
same implementation.
Matrices follow SimVX's row-major convention on the Python side. The
cascade VPs are returned in row-major layout; callers that hand them to a
column-major GPU API (Vulkan, WebGPU) must transpose at the boundary.
"""
from __future__ import annotations
import numpy as np
__all__ = [
    "DEFAULT_CASCADE_COUNT",
    "DEFAULT_LAMBDA",
    # Exported alongside the other defaults: it is the default value of the
    # public ``compute_cascades(padding=...)`` parameter.
    "DEFAULT_PADDING",
    "compute_cascades",
    "compute_splits",
    "depth_to_ndc",
]
# Default tuning — matches desktop ShadowPass.
DEFAULT_CASCADE_COUNT = 3
DEFAULT_LAMBDA = 0.5  # practical-split blend: log × λ + linear × (1-λ)
DEFAULT_PADDING = 0.1  # bounding-box pad as fraction of cascade extent
[docs]
def depth_to_ndc(z_eye: float, proj: np.ndarray) -> float:
    """Map an eye-space depth to Vulkan/WebGPU NDC z in [0, 1].

    ``z_eye`` is a positive distance in front of the camera. ``proj`` is a
    row-major 4×4 perspective matrix in the standard Vulkan convention. In
    that convention the point sits at view-space z = -z_eye, so
    clip_z = proj[2][2]·(-z_eye) + proj[2][3] and clip_w = z_eye.

    Returns:
        The perspective-divided depth clip_z / clip_w.
    """
    depth_scale = proj[2, 2]
    depth_offset = proj[2, 3]
    # clip_w equals the eye distance itself, so divide by z_eye directly.
    return (depth_offset - depth_scale * z_eye) / z_eye
[docs]
def compute_splits(
    near: float,
    far: float,
    cascade_count: int = DEFAULT_CASCADE_COUNT,
    lambda_split: float = DEFAULT_LAMBDA,
) -> np.ndarray:
    """Return ``cascade_count + 1`` split distances in eye space.

    Uses the practical split scheme, blending logarithmic and linear splits
    by ``lambda_split``. For ``p = i / cascade_count`` each split is
    ``λ·near·(far/near)^p + (1-λ)·(near + (far-near)·p)``. So
    ``lambda_split = 1`` is purely logarithmic and ``0`` purely linear.
    Matches Unity URP / UE / Godot tuning.

    Args:
        near: Eye-space distance of the first split. Must be > 0 because the
            logarithmic term divides by it.
        far: Eye-space distance of the last split. Must be > ``near``.
        cascade_count: Number of cascades. Must be >= 0.
        lambda_split: Weight of the logarithmic term in the blend.

    Returns:
        float32 array ``[near, s_1, ..., s_cascade_count]``.

    Raises:
        ValueError: if ``cascade_count < 0`` or the range is not
            ``0 < near < far``.
    """
    if cascade_count < 0:
        raise ValueError(f"cascade_count must be >= 0, got {cascade_count}")
    # near <= 0 divides by zero (or takes a fractional power of a negative
    # base) in the log term. far <= near yields degenerate or decreasing
    # splits. Reject both with a clear message instead.
    if not 0.0 < near < far:
        raise ValueError(
            f"compute_splits requires 0 < near < far, got near={near}, far={far}"
        )
    ratio = far / near  # loop-invariant base of the logarithmic term
    splits = np.empty(cascade_count + 1, dtype=np.float32)
    splits[0] = near
    for i in range(1, cascade_count + 1):
        p = i / cascade_count
        log_split = near * ratio**p
        lin_split = near + (far - near) * p
        splits[i] = lambda_split * log_split + (1.0 - lambda_split) * lin_split
    return splits
def _light_view_matrix(light_dir: np.ndarray, centroid: np.ndarray) -> np.ndarray:
    """Row-major view matrix for an orthographic light camera.

    Rows 0-2 hold the light-space basis (right, up, forward). Forward is the
    normalized ``light_dir``. The translation column moves ``centroid`` to
    the light-space origin.

    Args:
        light_dir: 3-vector direction (array or sequence); need not be unit.
        centroid: World-space point that becomes the light-space origin.

    Returns:
        float32 4×4 row-major view matrix.

    Raises:
        ValueError: if ``light_dir`` has zero or non-finite length. Without
            this check the normalization would silently fill the matrix
            with NaN.
    """
    direction = np.asarray(light_dir)
    length = np.linalg.norm(direction)
    if not np.isfinite(length) or length == 0.0:
        raise ValueError(
            f"light_dir must be a finite, non-zero vector, got {light_dir!r}"
        )
    forward = direction / length
    # Prefer world +Y as up. Fall back to +X when the light is (nearly)
    # vertical, so the cross product below stays well-conditioned.
    up = np.array([0, 1, 0], dtype=np.float32)
    if abs(np.dot(forward, up)) > 0.99:
        up = np.array([1, 0, 0], dtype=np.float32)
    right = np.cross(up, forward)
    right /= np.linalg.norm(right)
    up = np.cross(forward, right)  # re-orthogonalized up
    m = np.eye(4, dtype=np.float32)
    m[0, :3] = right
    m[1, :3] = up
    m[2, :3] = forward
    m[:3, 3] = -m[:3, :3] @ centroid
    return m
def _ortho_matrix(mins: np.ndarray, maxs: np.ndarray) -> np.ndarray:
    """Row-major orthographic projection around an axis-aligned box.

    The box ``[mins, maxs]`` maps X and Y to [-1, 1] and Z to [0, 1]. Only
    the first three components of ``mins`` / ``maxs`` are read.

    Returns:
        float32 4×4 row-major projection matrix.
    """
    lo_x, lo_y, lo_z = mins[0], mins[1], mins[2]
    hi_x, hi_y, hi_z = maxs[0], maxs[1], maxs[2]
    width = hi_x - lo_x
    height = hi_y - lo_y
    depth = hi_z - lo_z
    return np.array(
        [
            [2.0 / width, 0.0, 0.0, -(hi_x + lo_x) / width],
            [0.0, 2.0 / height, 0.0, -(hi_y + lo_y) / height],
            # Depth: lo_z -> 0, hi_z -> 1.
            [0.0, 0.0, 1.0 / depth, -lo_z / depth],
            [0.0, 0.0, 0.0, 1.0],
        ],
        dtype=np.float32,
    )
[docs]
# NDC (x, y) of the four corners of a frustum cross-section, in the order
# (-x -y), (+x -y), (-x +y), (+x +y).
_NDC_XY_CORNERS = ((-1, -1), (1, -1), (-1, 1), (1, 1))


def _planes_from_projection(proj: np.ndarray) -> tuple[float, float]:
    """Recover ``(near, far)`` eye distances from a row-major perspective matrix.

    Assumes Vulkan-style terms ``proj[2,2] = f/(n-f)`` and
    ``proj[2,3] = n·f/(n-f)``, which invert to ``near = proj[2,3]/proj[2,2]``
    and ``far = proj[2,3]/(proj[2,2] + 1)``. An infinite-far projection has
    ``proj[2,2] == -1``. In that case ``far`` is reported as ``inf`` instead
    of dividing by zero. Falls back to ``(0.1, 100.0)`` when ``proj[2,2]``
    is ~0.
    """
    p22, p23 = float(proj[2, 2]), float(proj[2, 3])
    if abs(p22) > 1e-6:
        denom = p22 + 1.0
        far_auto = p23 / denom if abs(denom) > 1e-12 else float("inf")
        return p23 / p22, far_auto
    return 0.1, 100.0


def _resolve_depth_range(
    proj: np.ndarray,
    near: float | None,
    far: float | None,
    max_far: float,
) -> tuple[float, float]:
    """Pick the ``(near, far)`` eye-distance range the cascades cover.

    Positive ``near`` / ``far`` are used as given. Missing or non-positive
    values are recovered from ``proj``. Only a far value recovered from
    ``proj`` is clamped at ``max_far``; an explicit ``far`` is never altered.
    """
    if near is not None and far is not None and near > 0 and far > 0:
        return float(near), float(far)
    near_auto, far_auto = _planes_from_projection(proj)
    n = float(near) if near is not None and near > 0 else near_auto
    f = float(far) if far is not None and far > 0 else float(min(far_auto, max_far))
    return n, f


def _frustum_slice_corners(
    inv_vp: np.ndarray, ndc_near: float, ndc_far: float
) -> np.ndarray:
    """World-space ``(8, 3)`` corners of the camera frustum between two NDC depths."""
    corners_ndc = np.array(
        [[x, y, z, 1] for z in (ndc_near, ndc_far) for x, y in _NDC_XY_CORNERS],
        dtype=np.float32,
    )
    corners_h = (inv_vp @ corners_ndc.T).T
    # Homogeneous divide back to Cartesian world coordinates.
    return corners_h[:, :3] / corners_h[:, 3:4]


def _fit_light_vp(
    corners_xyz: np.ndarray, light_dir: np.ndarray, padding: float
) -> np.ndarray:
    """Row-major light VP whose ortho box tightly (plus padding) encloses the corners."""
    centroid = corners_xyz.mean(axis=0)
    light_view = _light_view_matrix(light_dir, centroid)
    ones = np.ones((corners_xyz.shape[0], 1), dtype=np.float32)
    corners_ls = (light_view @ np.hstack([corners_xyz, ones]).T).T[:, :3]
    mins = corners_ls.min(axis=0)
    maxs = corners_ls.max(axis=0)
    # Pad every axis by a fraction of the cascade's light-space extent.
    pad = (maxs - mins) * padding
    return _ortho_matrix(mins - pad, maxs + pad) @ light_view


def compute_cascades(
    view: np.ndarray,
    proj: np.ndarray,
    light_dir: np.ndarray,
    near: float | None = None,
    far: float | None = None,
    cascade_count: int = DEFAULT_CASCADE_COUNT,
    lambda_split: float = DEFAULT_LAMBDA,
    padding: float = DEFAULT_PADDING,
    max_far: float = 300.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute CSM cascade splits and row-major light VP matrices.

    ``near`` / ``far`` default to values extracted from ``proj``. Missing
    or non-positive arguments count as "not given". A far value extracted
    from ``proj`` is clamped at ``max_far`` to keep shadow-map resolution
    reasonable for wide-open scenes, including infinite-far projections. An
    explicitly passed ``far`` is used as-is.

    Each cascade's VP is an orthographic projection fitted around the
    light-space bounding box of the camera-frustum slice between
    consecutive splits. The box is widened on every axis by ``padding`` ×
    its extent.

    Returns ``(cascade_vps[cascade_count, 4, 4], splits[cascade_count + 1])``
    where ``cascade_vps`` are row-major 4×4 (transpose at the GPU boundary).

    Raises:
        ValueError: if the resolved depth range is not ``0 < near < far``.
    """
    n, f = _resolve_depth_range(proj, near, far, max_far)
    if not 0.0 < n < f:
        raise ValueError(
            f"compute_cascades needs 0 < near < far, resolved near={n}, far={f}"
        )
    splits = compute_splits(n, f, cascade_count, lambda_split)
    inv_vp = np.linalg.inv(proj @ view)
    cascade_vps = np.zeros((cascade_count, 4, 4), dtype=np.float32)
    for c in range(cascade_count):
        corners = _frustum_slice_corners(
            inv_vp,
            depth_to_ndc(splits[c], proj),
            depth_to_ndc(splits[c + 1], proj),
        )
        cascade_vps[c] = _fit_light_vp(corners, light_dir, padding)
    return cascade_vps, splits