"""Skeleton and bone hierarchy for skeletal animation.
Skeleton is a Node3D that manages a flat array of Bone data. Each frame it
walks the parent chain, computes world transforms, and produces GPU-ready
joint matrices (world * inverse_bind) suitable for SSBO upload.
SkeletonProfile provides a standard naming convention so retargeting,
animation libraries, and humanoid IK can agree on bone names.
"""
import logging
from dataclasses import dataclass, field
import numpy as np
from .descriptors import Signal
from .nodes_3d.node3d import Node3D
log = logging.getLogger(__name__)

# Names exported by ``from <this module> import *``.
__all__ = ["Bone", "Skeleton", "SkeletonProfile", "PROFILE_HUMANOID"]
# ============================================================================
# Bone — data record (no scene-tree node, just structured data)
# ============================================================================
@dataclass
class Bone:
    """Single bone in a skeleton hierarchy.

    Attributes:
        name: Bone name, matched by ``Skeleton.find_bone``.
        parent_index: Index of the parent bone within ``Skeleton.bones``;
            any negative value marks a root bone.
        inverse_bind_matrix: 4x4 matrix right-multiplied onto the bone's
            computed transform to form its joint matrix.
        local_transform: Default 4x4 parent-relative transform, used when no
            pose override is supplied for this bone.

    Both matrix fields default to a fresh float32 4x4 identity per instance.
    """

    name: str = ""
    parent_index: int = -1
    inverse_bind_matrix: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
    local_transform: np.ndarray = field(default_factory=lambda: np.eye(4, dtype=np.float32))
# ============================================================================
# Skeleton — Node3D managing a bone hierarchy
# ============================================================================
class Skeleton(Node3D):
    """Bone hierarchy with GPU-ready joint matrix computation.

    Joint matrices = parent_world * local_transform * inverse_bind_matrix

    These are uploaded to an SSBO for vertex skinning in the shader.
    As a ``Node3D`` it participates in the scene tree, inherits a 3D
    transform, and can be animated/parented like any other spatial node.

    Bones may appear in ``bones`` in any order: a parent does not have to
    precede its children. A bone whose ``parent_index`` is out of range, or
    which closes a parent cycle, is logged and evaluated as a root.

    Signals:
        bone_pose_changed: Emitted after ``compute_pose`` updates joint
            matrices. Connected ``MeshInstance3D`` nodes (via ``skin``)
            can listen to know when to re-upload skinning data.
    """

    def __init__(self, bones: list[Bone] | None = None, **kwargs):
        super().__init__(**kwargs)
        # NOTE: a non-empty list is stored as-is (not copied), so the caller
        # and the skeleton share it.
        self.bones: list[Bone] = bones or []
        # (bone_count, 4, 4) float32 joint matrices; None until first needed.
        self._joint_matrices: np.ndarray | None = None
        # (bone_count, 4, 4) per-bone transforms from the last compute_pose.
        self._world_transforms: np.ndarray | None = None
        # bone_index -> 4x4 float32 local transform set via set_bone_pose.
        self._bone_overrides: dict[int, np.ndarray] = {}
        self.bone_pose_changed = Signal()

    # -- Bone count -----------------------------------------------------------

    @property
    def bone_count(self) -> int:
        return len(self.bones)

    # -- Joint matrices -------------------------------------------------------

    @property
    def joint_matrices(self) -> np.ndarray:
        """Get computed joint matrices, shape (bone_count, 4, 4).

        Call ``compute_pose()`` first. Before that, or after the bone count
        has changed, this returns one float32 identity matrix per bone.
        """
        n = self.bone_count
        # Resize on a count mismatch too (e.g. bones appended directly to
        # ``self.bones``), matching the reallocation rule in compute_pose.
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            self._joint_matrices = np.tile(np.eye(4, dtype=np.float32), (n, 1, 1))
        return self._joint_matrices

    # -- Pose management ------------------------------------------------------

    def set_bone_pose(self, bone_index: int, transform: np.ndarray) -> None:
        """Override a single bone's local transform for the current pose.

        The override persists until cleared via ``clear_bone_pose`` or another
        ``set_bone_pose`` call. Call ``compute_pose()`` afterwards (or let
        ``process()`` do it) to propagate the change.

        *transform* is copied and stored as a float32 4x4 matrix, so later
        in-place edits to the caller's array do not affect the override.
        """
        # np.array (not np.asarray) always copies; asarray would alias a
        # caller-owned float32 array.
        self._bone_overrides[bone_index] = np.array(transform, dtype=np.float32).reshape(4, 4)

    def get_bone_pose(self, bone_index: int) -> np.ndarray:
        """Return the current local transform for a bone (override or default)."""
        if bone_index in self._bone_overrides:
            return self._bone_overrides[bone_index].copy()
        if 0 <= bone_index < self.bone_count:
            return self.bones[bone_index].local_transform.copy()
        raise IndexError(f"bone_index {bone_index} out of range (bone_count={self.bone_count})")

    def clear_bone_pose(self, bone_index: int) -> None:
        """Remove the per-bone pose override, reverting to the bone's default local_transform."""
        self._bone_overrides.pop(bone_index, None)

    def clear_all_bone_poses(self) -> None:
        """Remove all per-bone pose overrides."""
        self._bone_overrides.clear()

    # -- Pose computation -----------------------------------------------------

    def _evaluation_order(self) -> list[tuple[int, int]]:
        """Return ``(bone_index, effective_parent)`` pairs, parents before children.

        ``effective_parent`` is -1 for root bones, for bones whose
        ``parent_index`` is >= bone_count, and for the bone at which a parent
        cycle is detected. The last two cases are logged. For bones already
        stored parents-first, the order is simply 0..bone_count-1.
        """
        n = self.bone_count
        effective_parent = [-1] * n
        state = [0] * n  # 0 = unvisited, 1 = on the chain being walked, 2 = placed
        order: list[tuple[int, int]] = []
        for start in range(n):
            # Walk upward from `start` until reaching a root, an invalid
            # parent, an already-placed ancestor, or a cycle.
            chain: list[int] = []
            i = start
            while state[i] == 0:
                state[i] = 1
                chain.append(i)
                parent = self.bones[i].parent_index
                if parent < 0:
                    break
                if parent >= n:
                    log.warning("Bone %s has invalid parent_index %s (bone_count=%s)", i, parent, n)
                    break
                if state[parent] == 1:
                    log.warning("Bone %s has cyclic parent_index %s; treating it as a root", i, parent)
                    break
                effective_parent[i] = parent
                i = parent
            # `chain` runs child -> ancestor; emit ancestors first.
            for j in reversed(chain):
                state[j] = 2
                order.append((j, effective_parent[j]))
        return order

    def compute_pose(self, bone_transforms: dict[int, np.ndarray] | None = None) -> None:
        """Compute final joint matrices from bone-local transforms.

        Args:
            bone_transforms: Optional *additional* overrides for bone-local
                transforms (merged on top of ``set_bone_pose`` overrides).
                Maps bone_index -> 4x4 local transform matrix.
                Bones not in any override use their default ``local_transform``.

        Emits ``bone_pose_changed`` when at least one bone exists.
        """
        n = self.bone_count
        if n == 0:
            # Drop caches left over from a previous, non-empty bone list.
            self._joint_matrices = np.zeros((0, 4, 4), dtype=np.float32)
            self._world_transforms = np.zeros((0, 4, 4), dtype=np.float32)
            return
        if self._joint_matrices is None or len(self._joint_matrices) != n:
            self._joint_matrices = np.zeros((n, 4, 4), dtype=np.float32)

        # Merge overrides: set_bone_pose overrides first, then explicit arg on top.
        merged = dict(self._bone_overrides)
        if bone_transforms:
            merged.update(bone_transforms)

        # Accumulate transforms down the hierarchy. The evaluation order
        # guarantees world[parent] is final before any child reads it,
        # whatever order the bones are stored in.
        world = np.zeros((n, 4, 4), dtype=np.float32)
        for i, parent in self._evaluation_order():
            bone = self.bones[i]
            local = merged[i] if i in merged else bone.local_transform
            world[i] = local if parent < 0 else world[parent] @ local
            # Joint matrix = world * inverse_bind
            self._joint_matrices[i] = world[i] @ bone.inverse_bind_matrix

        # Cache transforms for get_bone_global_transform.
        self._world_transforms = world

        # Notify listeners (e.g. skinned MeshInstance3D nodes).
        self.bone_pose_changed()

    def get_bone_global_transform(self, bone_index: int) -> np.ndarray:
        """Return a copy of a bone's accumulated transform from the last ``compute_pose``.

        The transform is the product of local transforms from the bone's root
        down to the bone. This node's own ``Node3D`` transform is not applied.
        If no pose has been computed for the current bone count,
        ``compute_pose()`` runs first, which emits ``bone_pose_changed``.

        Raises:
            IndexError: if *bone_index* is outside ``[0, bone_count)``.
        """
        n = self.bone_count
        if not 0 <= bone_index < n:
            raise IndexError(f"bone_index {bone_index} out of range (bone_count={n})")
        if self._world_transforms is None or len(self._world_transforms) != n:
            self.compute_pose()
        return self._world_transforms[bone_index].copy()

    # -- Scene-tree integration -----------------------------------------------

    def process(self, dt: float) -> None:
        """Auto-recompute pose each frame if there are any overrides."""
        if self._bone_overrides:
            self.compute_pose()

    # -- Bone lookup ----------------------------------------------------------

    def find_bone(self, name: str) -> int:
        """Find bone index by name. Returns -1 (and logs a warning) if not found."""
        for i, bone in enumerate(self.bones):
            if bone.name == name:
                return i
        log.warning("Bone %r not found in skeleton (%s bones)", name, self.bone_count)
        return -1

    def add_bone(self, bone: Bone) -> int:
        """Append a bone and return its index."""
        idx = len(self.bones)
        self.bones.append(bone)
        # Invalidate both caches so they are rebuilt at the new size.
        self._joint_matrices = None
        self._world_transforms = None
        return idx
# ============================================================================
# SkeletonProfile — standard bone naming convention
# ============================================================================
[docs]
@dataclass
class SkeletonProfile:
"""Standard bone naming convention for retargeting and humanoid IK.
A profile defines an ordered list of bone names, optional parent
relationships (by name), and bone groups for logical grouping.
Usage::
profile = PROFILE_HUMANOID
idx = skel.find_bone(profile.bone_names[0]) # "Hips"
# Validate a skeleton against the profile
missing = profile.validate(skel)
"""
name: str
bone_names: list[str] = field(default_factory=list)
bone_parents: dict[str, str] = field(default_factory=dict)
bone_groups: dict[str, list[str]] = field(default_factory=dict)
[docs]
def validate(self, skeleton: Skeleton) -> list[str]:
"""Return a list of profile bone names missing from *skeleton*."""
return [name for name in self.bone_names if skeleton.find_bone(name) == -1]
[docs]
def find_in_skeleton(self, skeleton: Skeleton, profile_bone_name: str) -> int:
"""Look up a profile bone name in a skeleton. Returns -1 if not found."""
return skeleton.find_bone(profile_bone_name)
[docs]
def get_parent(self, bone_name: str) -> str | None:
"""Return the profile-defined parent bone name, or None for root bones."""
return self.bone_parents.get(bone_name)
[docs]
def get_group(self, group_name: str) -> list[str]:
"""Return bone names belonging to a group, or empty list."""
return self.bone_groups.get(group_name, [])
# -- Built-in humanoid profile -----------------------------------------------
# Body regions, each listed root-to-tip. Concatenated in this order, the
# regions reproduce _HUMANOID_BONES exactly.
_HUMANOID_GROUPS = {
    "torso": ["Hips", "Spine", "Chest", "UpperChest", "Neck", "Head"],
    "left_arm": ["LeftShoulder", "LeftUpperArm", "LeftLowerArm", "LeftHand"],
    "right_arm": ["RightShoulder", "RightUpperArm", "RightLowerArm", "RightHand"],
    "left_leg": ["LeftUpperLeg", "LeftLowerLeg", "LeftFoot"],
    "right_leg": ["RightUpperLeg", "RightLowerLeg", "RightFoot"],
}

# Canonical profile order: torso, left arm, right arm, left leg, right leg.
_HUMANOID_BONES = [
    "Hips", "Spine", "Chest", "UpperChest", "Neck", "Head",
    "LeftShoulder", "LeftUpperArm", "LeftLowerArm", "LeftHand",
    "RightShoulder", "RightUpperArm", "RightLowerArm", "RightHand",
    "LeftUpperLeg", "LeftLowerLeg", "LeftFoot",
    "RightUpperLeg", "RightLowerLeg", "RightFoot",
]

# child -> parent. "Hips" is the single root, so it has no entry.
_HUMANOID_PARENTS = {
    # Spine chain up to the head.
    "Spine": "Hips", "Chest": "Spine", "UpperChest": "Chest",
    "Neck": "UpperChest", "Head": "Neck",
    # Arms attach to the upper chest.
    "LeftShoulder": "UpperChest", "LeftUpperArm": "LeftShoulder",
    "LeftLowerArm": "LeftUpperArm", "LeftHand": "LeftLowerArm",
    "RightShoulder": "UpperChest", "RightUpperArm": "RightShoulder",
    "RightLowerArm": "RightUpperArm", "RightHand": "RightLowerArm",
    # Legs attach to the hips.
    "LeftUpperLeg": "Hips", "LeftLowerLeg": "LeftUpperLeg", "LeftFoot": "LeftLowerLeg",
    "RightUpperLeg": "Hips", "RightLowerLeg": "RightUpperLeg", "RightFoot": "RightLowerLeg",
}

# Fields in declaration order: name, bone_names, bone_parents, bone_groups.
PROFILE_HUMANOID = SkeletonProfile(
    "Humanoid",
    _HUMANOID_BONES,
    _HUMANOID_PARENTS,
    _HUMANOID_GROUPS,
)