Source code for simvx.core.hot_reload

"""
Hot-reload system -- watch script files for changes and reload them live.

On file change: serialize node state -> importlib.reload(module) -> instantiate
new class -> restore state.  Non-serializable state is warned and skipped.

Public API:
    from simvx.core.hot_reload import HotReloadManager

    mgr = HotReloadManager(tree)
    mgr.watch("game.py")          # start watching a module file
    mgr.poll(dt)                  # call each frame with the frame's delta time (seconds)
"""

import importlib
import logging
import os
import sys
import types
from pathlib import Path
from typing import Any

from .descriptors import Signal
from .node import Node
from .scene_tree import SceneTree

log = logging.getLogger(__name__)

__all__ = ["HotReloadManager"]

def _serialize_settings(node: Node) -> dict[str, Any]:
    """Snapshot the node's declared Property values into a ``name -> value`` dict.

    Each entry of ``node.get_properties()`` is read through its backing
    attribute (``prop.attr``), falling back to ``prop.default`` when that
    attribute is not set.  Callable values are left out of the snapshot,
    except for classes, which are kept.

    Args:
        node: Node whose properties are captured.

    Returns:
        Mapping of property name to the value currently held by *node*.
    """
    snapshot: dict[str, Any] = {}
    for prop_name, descriptor in node.get_properties().items():
        try:
            value = getattr(node, descriptor.attr, descriptor.default)
            # Keep plain data and classes; functions, bound methods and other
            # callables are not state worth carrying across a reload.
            if not callable(value) or isinstance(value, type):
                snapshot[prop_name] = value
        except (AttributeError, TypeError):
            log.debug("hot_reload: skipping non-serializable setting %s on %s", prop_name, node.name)
    return snapshot

def _serialize_node_state(node: Node) -> dict[str, Any]:
    """Build a nested dict describing *node* and, recursively, its children.

    The result always contains ``name``, ``class``, ``module``, ``settings``,
    ``groups``, ``visible`` and ``children``.  ``position``, ``rotation``,
    ``scale`` and ``velocity`` are included only when the node has such an
    attribute; ``scale`` is additionally skipped when it is callable.

    Args:
        node: Root of the subtree to capture.

    Returns:
        The state dict for *node*; ``children`` holds one such dict per child.
    """
    node_type = type(node)
    snapshot: dict[str, Any] = {
        "name": node.name,
        "class": node_type.__name__,
        "module": node_type.__module__,
        "settings": _serialize_settings(node),
        "groups": list(node._groups),
        "visible": node.visible,
    }

    # Optional spatial attributes, captured only where the node provides them.
    for attr in ("position", "rotation", "scale", "velocity"):
        if not hasattr(node, attr):
            continue
        if attr == "scale" and callable(node.scale):
            # Some node types expose ``scale`` as a method rather than data.
            continue
        snapshot[attr] = getattr(node, attr)

    # Descend into the subtree.
    snapshot["children"] = [_serialize_node_state(child) for child in node.children]
    return snapshot

def _restore_settings(node: Node, state: dict[str, Any]):
    """Write saved property values back onto *node*.

    Names in *state* that are not among ``node.get_properties()`` are ignored.
    Any exception raised while looking up or assigning a single setting is
    logged as a warning and does not stop the remaining settings from being
    applied.

    Args:
        node: Node receiving the values.
        state: Mapping of property name to saved value.
    """
    for key, saved in state.items():
        try:
            if key not in node.get_properties():
                continue
            setattr(node, key, saved)
        except Exception as exc:
            log.warning("hot_reload: failed to restore setting %s=%r on %s: %s", key, saved, node.name, exc)

def _restore_node_state(node: Node, state: dict[str, Any]):
    """Apply a state dict produced by ``_serialize_node_state`` back to *node*.

    Restores, in order: ``name`` and ``visible``, group membership, the
    optional spatial attributes (``position``, ``rotation``, ``scale``,
    ``velocity``) and finally the Property settings.  The ``children`` entry
    of *state* is not consumed here.

    A spatial attribute is only written when it is present in *state* and the
    node has an attribute of that name; ``scale`` is skipped when the node
    exposes it as a callable.  A ``KeyError``, ``TypeError`` or ``ValueError``
    raised by the assignment is logged as a warning and the attribute is
    skipped, so one incompatible value does not abort the rest of the restore.

    Args:
        node: Node to update in place.
        state: Saved state; missing keys leave the corresponding field alone
            (``visible`` defaults to ``True``).
    """
    node.name = state.get("name", node.name)
    node.visible = state.get("visible", True)

    for group in state.get("groups", []):
        node.add_to_group(group)

    for attr in ("position", "rotation", "scale", "velocity"):
        if attr not in state or not hasattr(node, attr):
            continue
        if attr == "scale" and callable(node.scale):
            # ``scale`` is a method on this node type, not restorable data.
            continue
        value = state[attr]
        try:
            setattr(node, attr, value)
        except (KeyError, TypeError, ValueError) as e:
            # Report the skipped value instead of dropping it silently, the
            # same way failed settings are reported.
            log.warning("hot_reload: failed to restore %s=%r on %s: %s", attr, value, node.name, e)

    _restore_settings(node, state.get("settings", {}))

class _WatchedFile:
    """Tracks a single file's modification time and associated module."""

    __slots__ = ("path", "module_name", "mtime")

    def __init__(self, path: str, module_name: str):
        self.path = path
        self.module_name = module_name
        self.mtime = self._stat()

    def _stat(self) -> float:
        try:
            return os.stat(self.path).st_mtime
        except OSError:
            return 0.0

    def changed(self) -> bool:
        """Return True if file was modified since last check."""
        new_mtime = self._stat()
        if new_mtime != self.mtime:
            self.mtime = new_mtime
            return True
        return False

class HotReloadManager:
    """Watches script files for changes and hot-reloads node classes.

    Usage:
        mgr = HotReloadManager(tree)
        mgr.watch("my_game.py")
        # In your game loop:
        mgr.poll(dt)  # checks once ``poll_interval`` seconds have accumulated

    Attributes:
        enabled (bool): When False, :meth:`poll` returns immediately without
            checking files or reloading. Watched files are retained so toggling
            ``enabled`` back to True resumes change detection from the next
            poll. ``force_reload`` and ``reload_module`` ignore ``enabled``.

    Signals:
        module_reloaded(module_name: str, class_names: list[str])
            Emitted after a module is successfully reloaded and live nodes updated.
        reload_failed(module_name: str, error: str)
            Emitted when a reload attempt fails. The old version remains active.
        node_reinstanced(old_node: Node, new_node: Node)
            Emitted when a node is replaced with a fresh instance of its (reloaded) class.
    """

    def __init__(self, tree: SceneTree, poll_interval: float = 0.5):
        """Create a manager for *tree*.

        Args:
            tree: Scene tree searched for live nodes when a module is reloaded.
            poll_interval: Seconds of accumulated ``dt`` required between
                file checks in :meth:`poll`.
        """
        self.tree = tree
        self.poll_interval = poll_interval
        self.enabled: bool = True

        self.module_reloaded = Signal()  # (module_name, class_names)
        self.reload_failed = Signal()  # (module_name, error_str)
        self.node_reinstanced = Signal()  # (old_node, new_node)

        self._watched: list[_WatchedFile] = []
        self._time_accumulator: float = 0.0
        # Map file path -> module name for files registered via watch
        self._path_to_module: dict[str, str] = {}

    # -- File watching ---------------------------------------------------------

    def watch(self, file_path: str) -> None:
        """Start watching a Python file for changes.

        If no loaded module matches the file, the file's directory is put at
        the front of ``sys.path`` (when not already listed) and the module is
        imported by its stem.  A missing file or a failed import is logged as
        a warning and the file is not watched.  Watching the same resolved
        path twice is a no-op.

        Args:
            file_path: Path to a .py file. The corresponding module must be importable.
        """
        path = str(Path(file_path).resolve())
        if not Path(path).is_file():
            log.warning("hot_reload: file not found: %s", path)
            return

        # Find module name from sys.modules by matching file path
        module_name = self._find_module_name(path)
        if module_name is None:
            # Try importing from filename
            module_name = Path(path).stem
            # Add directory to sys.path if needed
            parent = str(Path(path).parent)
            if parent not in sys.path:
                sys.path.insert(0, parent)
            try:
                importlib.import_module(module_name)
            except ImportError:
                log.warning("hot_reload: could not import module for %s", path)
                return

        # Don't double-watch
        if any(w.path == path for w in self._watched):
            return

        self._watched.append(_WatchedFile(path, module_name))
        self._path_to_module[path] = module_name
        log.debug("hot_reload: watching %s (%s)", path, module_name)

    def unwatch(self, file_path: str) -> None:
        """Stop watching a file."""
        path = str(Path(file_path).resolve())
        self._watched = [w for w in self._watched if w.path != path]
        self._path_to_module.pop(path, None)

    @property
    def watched_files(self) -> list[str]:
        """Return list of currently watched file paths."""
        return [w.path for w in self._watched]

    # -- Polling ---------------------------------------------------------------

    def poll(self, dt: float = 0.0) -> list[str]:
        """Check watched files for changes. Call each frame with delta time.

        Returns an empty list immediately when :attr:`enabled` is False, so the
        per-frame call site stays cheap when hot reload is toggled off.

        *dt* is added to an internal accumulator and files are only checked
        once that accumulator reaches ``poll_interval``.  With the default
        ``dt=0.0`` the accumulator never advances, so callers must pass the
        elapsed time for checks to happen when ``poll_interval`` is positive.

        A failed reload is logged and reported through ``reload_failed``; it
        does not stop the remaining watched files from being checked.

        Args:
            dt: Seconds elapsed since the previous call.

        Returns:
            List of module names that were reloaded.
        """
        if not self.enabled:
            return []

        self._time_accumulator += dt
        if self._time_accumulator < self.poll_interval:
            return []
        self._time_accumulator = 0.0

        reloaded: list[str] = []
        for wf in self._watched:
            # force_reload() carries the log-and-emit failure handling.
            if wf.changed() and self.force_reload(wf.module_name):
                reloaded.append(wf.module_name)
        return reloaded

    # -- Public reload API -----------------------------------------------------

    def force_reload(self, module_name: str) -> bool:
        """Force-reload a specific module by name.

        Any exception raised by the reload is logged and emitted through
        ``reload_failed`` as ``"<ExceptionType>: <message>"``.

        Returns:
            True if reload succeeded, False otherwise.
        """
        try:
            self._reload_module(module_name)
            return True
        except Exception as e:
            error_msg = f"{type(e).__name__}: {e}"
            log.error("hot_reload: failed to reload %s: %s", module_name, error_msg)
            self.reload_failed.emit(module_name, error_msg)
            return False

    def reload_module(self, file_path: str) -> bool:
        """Reload a module by its file path.

        Reimports the module, finds changed classes, and updates live nodes.
        On failure, keeps the old version and emits ``reload_failed``.

        The module name is taken from the watch registry, then from
        ``sys.modules``, and finally falls back to the file's stem.

        Args:
            file_path: Path to the .py file to reload.

        Returns:
            True if reload succeeded, False otherwise.
        """
        path = str(Path(file_path).resolve())
        module_name = self._path_to_module.get(path) or self._find_module_name(path)
        if module_name is None:
            module_name = Path(path).stem
        return self.force_reload(module_name)

    # -- Node re-instantiation -------------------------------------------------

    def reinstance_node(self, old_node: Node, new_class: type[Node]) -> Node:
        """Create a new instance of *new_class*, copying state and children from *old_node*.

        The new node replaces *old_node* in the tree (same parent, same position
        among siblings).  Clears any ``_script_error`` flag so processing resumes.
        Emits ``node_reinstanced(old_node, new_node)``.

        When *old_node* has no parent, only its state is copied: the tree is
        not modified and its children stay attached to *old_node*.

        Returns:
            The newly created node.
        """
        state = _serialize_node_state(old_node)
        parent = old_node.parent

        # Create fresh instance
        new_node = new_class()
        _restore_node_state(new_node, state)
        new_node._script_error = False

        if parent is not None:
            # Find the old node's index among siblings
            idx = None
            for i, child in enumerate(parent.children):
                if child is old_node:
                    idx = i
                    break

            # Detach old, reparent children, insert new at same position
            children = list(old_node.children)
            old_node.clear_children()
            parent.remove_child(old_node)
            parent.add_child(new_node)

            # Move to original index if possible
            if idx is not None and idx < len(parent.children) - 1:
                # Re-order: remove from end and insert at idx
                parent.children._list.remove(new_node)
                parent.children._list.insert(idx, new_node)
                parent.children._dirty = True

            # Re-attach original children to new node
            for child in children:
                new_node.add_child(child)

        self.node_reinstanced.emit(old_node, new_node)
        log.info("hot_reload: reinstanced %s as %s", old_node.name, new_class.__name__)
        return new_node

    # -- Internal reload logic -------------------------------------------------

    def _reload_module(self, module_name: str):
        """Reload a module and replace live node instances with new class versions.

        Live nodes whose exact type is a ``Node`` subclass defined in the
        module have their state captured, the module is reloaded, and each
        node whose class name still exists afterwards gets its ``__class__``
        swapped to the new class and its state re-applied.  Exceptions from
        the import or reload propagate to the caller.
        """
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)

        # Collect old classes from the module
        old_classes = {
            name: obj
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj.__module__ == module_name
        }

        # Find live nodes using old classes (including errored ones)
        live_nodes = self._find_live_nodes(old_classes)

        # Serialize state of live nodes before reload
        saved_states: dict[int, dict[str, Any]] = {}
        errored_nodes: dict[int, bool] = {}
        for node in live_nodes:
            saved_states[id(node)] = _serialize_node_state(node)
            errored_nodes[id(node)] = node._script_error

        # Invalidate bytecode cache so importlib.reload reads fresh source
        self._invalidate_bytecode_cache(module)

        # Reload the module -- on failure, the old module stays in sys.modules
        module = importlib.reload(module)

        # Get new classes
        new_classes = {
            name: obj
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj.__module__ == module_name
        }

        # Replace nodes with new class instances
        replaced_classes: list[str] = []
        for node in live_nodes:
            state = saved_states.get(id(node))
            if state is None:
                continue

            class_name = state["class"]
            new_cls = new_classes.get(class_name)
            if new_cls is None:
                # Class was removed or renamed by the edit; leave the node as is.
                continue

            was_errored = errored_nodes.get(id(node), False)

            # Swap the class on the existing instance (avoids re-parenting)
            node.__class__ = new_cls
            _restore_node_state(node, state)

            # Clear error flag so the node resumes processing after a fix
            if was_errored:
                node._script_error = False
                log.info("hot_reload: cleared error on %s (class %s reloaded)", node.name, class_name)

            if class_name not in replaced_classes:
                replaced_classes.append(class_name)

        if replaced_classes:
            log.info("hot_reload: reloaded %s -> %s", module_name, replaced_classes)

        # NOTE(review): emitted for every successful reload, including when no
        # live node was updated -- confirm this matches the intended contract.
        self.module_reloaded.emit(module_name, replaced_classes)

    def _find_live_nodes(self, classes: dict[str, type]) -> list[Node]:
        """Find all nodes in the tree whose class is one of the given classes.

        Matching is on the exact type, so instances of subclasses defined
        elsewhere are not returned.
        """
        if not self.tree.root:
            return []
        class_set = set(classes.values())
        return [node for node in self.tree.root.walk() if type(node) in class_set]

    @staticmethod
    def _find_module_name(path: str) -> str | None:
        """Find the module name in sys.modules that corresponds to a file path.

        Args:
            path: File path to look up; compared after resolving both sides.

        Returns:
            Name of the first loaded module whose ``__file__`` resolves to
            *path*, or None when no loaded module matches.
        """
        target = Path(path).resolve()
        # Iterate over a snapshot: sys.modules can be mutated by an import
        # happening elsewhere while this loop runs.
        for name, mod in list(sys.modules.items()):
            if mod is None:
                continue
            mod_file = getattr(mod, "__file__", None)
            if mod_file and Path(mod_file).resolve() == target:
                return name
        return None

    @staticmethod
    def _invalidate_bytecode_cache(module: types.ModuleType) -> None:
        """Delete the .pyc file for *module* so importlib.reload reads fresh source.

        Best effort: modules without a ``__file__`` are ignored, and failure
        to compute or delete the cache path is not an error.
        """
        # ``import importlib`` alone does not guarantee the ``util`` submodule
        # is loaded, so import the helper explicitly.
        from importlib.util import cache_from_source

        src = getattr(module, "__file__", None)
        if not src:
            return
        try:
            cache_path = cache_from_source(src)
            if os.path.exists(cache_path):
                os.unlink(cache_path)
        except (NotImplementedError, OSError):
            # No bytecode cache support, or the file could not be removed;
            # the reload proceeds regardless.
            pass