"""
Hot-reload system -- watch script files for changes and reload them live.
On file change: snapshot the state of live nodes -> importlib.reload(module) ->
switch each live node to the reloaded class of the same name -> restore the
snapshot. State that cannot be captured or restored is logged and skipped.
Public API:
from simvx.core.hot_reload import HotReloadManager
mgr = HotReloadManager(tree)
mgr.watch("game.py") # start watching a module file
mgr.poll(dt) # call each frame with the frame's delta time in seconds
"""
import importlib
import importlib.util
import logging
import os
import sys
import time
import types
from pathlib import Path
from typing import Any

from .descriptors import Signal
from .node import Node
from .scene_tree import SceneTree
# Module logger; every message emitted in this file is prefixed with "hot_reload:".
log = logging.getLogger(__name__)
# Only the manager is public; the serialization helpers and _WatchedFile are internal.
__all__ = ["HotReloadManager"]
def _serialize_settings(node: Node) -> dict[str, Any]:
    """Snapshot a node's declared Property values as a ``{property_name: value}`` dict.

    Each value is read from the property's backing attribute (``prop.attr``),
    falling back to ``prop.default`` when that attribute is missing. Callable
    values other than classes are left out (they are not plain data), and a
    property whose read raises ``AttributeError`` or ``TypeError`` is skipped
    with a debug log entry. Values are stored by reference, not copied.
    """
    snapshot: dict[str, Any] = {}
    for prop_name, prop in node.get_properties().items():
        try:
            value = getattr(node, prop.attr, prop.default)
            # Functions / bound methods are not state; classes are kept.
            is_behaviour = callable(value) and not isinstance(value, type)
            if not is_behaviour:
                snapshot[prop_name] = value
        except (AttributeError, TypeError):
            log.debug("hot_reload: skipping non-serializable setting %s on %s", prop_name, node.name)
    return snapshot
def _serialize_node_state(node: Node) -> dict[str, Any]:
    """Build a plain-dict snapshot of *node* and, recursively, its children.

    Keys: ``name``, ``class`` / ``module`` (of the node's current type),
    ``settings`` (see ``_serialize_settings``), ``groups``, ``visible``, any of
    ``position`` / ``rotation`` / ``scale`` / ``velocity`` the node has, and
    ``children`` (a list of snapshots of the node's children).
    """
    node_type = type(node)
    snapshot: dict[str, Any] = {
        "name": node.name,
        "class": node_type.__name__,
        "module": node_type.__module__,
        "settings": _serialize_settings(node),
        "groups": list(node._groups),
        "visible": node.visible,
    }
    for attr in ("position", "rotation", "scale", "velocity"):
        if not hasattr(node, attr):
            continue
        # On some node types ``scale`` is a method rather than a value.
        if attr == "scale" and callable(node.scale):
            continue
        snapshot[attr] = getattr(node, attr)
    snapshot["children"] = list(map(_serialize_node_state, node.children))
    return snapshot
def _restore_settings(node: Node, state: dict[str, Any]):
    """Re-apply captured Property values to *node*.

    Args:
        node: Node to update.
        state: ``{property_name: value}`` mapping, as produced by ``_serialize_settings``.

    Names that are not Properties of the node's current class are skipped.
    Any exception raised while checking or assigning an entry is logged as a
    warning, and the remaining entries are still applied.
    """
    for prop_name, value in state.items():
        try:
            if prop_name not in node.get_properties():
                continue  # not a Property on the node's current class
            setattr(node, prop_name, value)
        except Exception as exc:
            log.warning("hot_reload: failed to restore setting %s=%r on %s: %s", prop_name, value, node.name, exc)
def _restore_node_state(node: Node, state: dict[str, Any]):
    """Apply a snapshot produced by ``_serialize_node_state`` back onto *node*.

    Restores, in order: name, visibility, group membership, the spatial
    attributes (``position`` / ``rotation`` / ``scale`` / ``velocity``) present
    in both the snapshot and the node, and finally the Property settings. The
    snapshot's ``children`` entry is not used here.

    A spatial attribute that cannot be assigned is skipped with a debug log
    entry instead of aborting the restore. This covers a wrong type, an invalid
    value, or an attribute that is no longer writable on a reloaded class
    (``AttributeError``).
    """
    node.name = state.get("name", node.name)
    node.visible = state.get("visible", True)
    for group in state.get("groups", []):
        node.add_to_group(group)
    for attr in ("position", "rotation", "scale", "velocity"):
        if attr not in state or not hasattr(node, attr):
            continue
        # On some node types ``scale`` is a method; never overwrite it.
        if attr == "scale" and callable(node.scale):
            continue
        try:
            setattr(node, attr, state[attr])
        except (AttributeError, KeyError, TypeError, ValueError) as exc:
            # AttributeError: e.g. the reloaded class made this a read-only property.
            log.debug("hot_reload: could not restore %s on %s: %s", attr, node.name, exc)
    _restore_settings(node, state.get("settings", {}))
class _WatchedFile:
"""Tracks a single file's modification time and associated module."""
__slots__ = ("path", "module_name", "mtime")
def __init__(self, path: str, module_name: str):
self.path = path
self.module_name = module_name
self.mtime = self._stat()
def _stat(self) -> float:
try:
return os.stat(self.path).st_mtime
except OSError:
return 0.0
def changed(self) -> bool:
"""Return True if file was modified since last check."""
new_mtime = self._stat()
if new_mtime != self.mtime:
self.mtime = new_mtime
return True
return False
class HotReloadManager:
    """Watches script files for changes and hot-reloads node classes.

    Usage::

        mgr = HotReloadManager(tree)
        mgr.watch("my_game.py")
        # In your game loop:
        mgr.poll(dt)  # checks files every poll_interval seconds

    When :meth:`poll` is called without *dt*, the elapsed wall-clock time since
    the previous enabled call (``time.monotonic``) is used instead.

    On reload, every live node whose exact type is a Node subclass defined in
    the reloaded module has its ``__class__`` switched to the reloaded class of
    the same name, and its captured state is re-applied. ``__init__`` is not
    re-run.

    Attributes:
        enabled (bool): When False, :meth:`poll` returns immediately without
            checking files or reloading. Watched files are retained so toggling
            ``enabled`` back to True resumes change detection from the next
            poll. ``force_reload`` and ``reload_module`` ignore ``enabled``.
        poll_interval (float): Minimum accumulated seconds between file checks.

    Signals:
        module_reloaded(module_name: str, class_names: list[str])
            Emitted after a module is successfully reloaded and at least one
            live node was switched to a reloaded class.
        reload_failed(module_name: str, error: str)
            Emitted when a reload attempt fails. Live nodes keep the classes
            they had before the attempt.
        node_reinstanced(old_node: Node, new_node: Node)
            Emitted by :meth:`reinstance_node` when a node is replaced with a
            fresh instance of its (reloaded) class.
    """

    def __init__(self, tree: SceneTree, poll_interval: float = 0.5):
        self.tree = tree
        self.poll_interval = poll_interval
        self.enabled: bool = True
        self.module_reloaded = Signal()  # (module_name, class_names)
        self.reload_failed = Signal()  # (module_name, error_str)
        self.node_reinstanced = Signal()  # (old_node, new_node)
        self._watched: list[_WatchedFile] = []
        self._time_accumulator: float = 0.0
        # time.monotonic() at the previous enabled poll(); used when dt is omitted.
        self._last_poll_time: float | None = None
        # Map file path -> module name for files registered via watch
        self._path_to_module: dict[str, str] = {}

    # -- File watching ---------------------------------------------------------

    def watch(self, file_path: str) -> None:
        """Start watching a Python file for changes.

        If no loaded module comes from this file, the file's directory is
        prepended to ``sys.path`` (if absent) and the module is imported under
        the file's stem. Missing files, modules that fail to import, and
        stem imports that resolve to a different file are logged as warnings
        and not watched. Watching an already-watched path is a no-op.

        Args:
            file_path: Path to a .py file. The corresponding module must be importable.
        """
        path = str(Path(file_path).resolve())
        if not Path(path).is_file():
            log.warning("hot_reload: file not found: %s", path)
            return
        # Don't double-watch (checked before any import side effects)
        if any(w.path == path for w in self._watched):
            return
        # Find module name from sys.modules by matching file path
        module_name = self._find_module_name(path)
        if module_name is None:
            # Try importing from filename
            module_name = Path(path).stem
            # Add directory to sys.path if needed
            parent = str(Path(path).parent)
            if parent not in sys.path:
                sys.path.insert(0, parent)
            try:
                module = importlib.import_module(module_name)
            except Exception as e:
                # Not only ImportError: a SyntaxError or a failing top-level
                # statement in the user's script must not escape watch().
                log.warning("hot_reload: could not import module for %s: %s", path, e)
                return
            # A module with the same name may already be cached from another
            # file; watching it would reload the wrong code.
            mod_file = getattr(module, "__file__", None)
            if not mod_file or Path(mod_file).resolve() != Path(path):
                log.warning(
                    "hot_reload: module %r is loaded from %s, not %s; not watching",
                    module_name,
                    mod_file,
                    path,
                )
                return
        self._watched.append(_WatchedFile(path, module_name))
        self._path_to_module[path] = module_name
        log.debug("hot_reload: watching %s (%s)", path, module_name)

    def unwatch(self, file_path: str) -> None:
        """Stop watching a file (no-op if it is not watched)."""
        path = str(Path(file_path).resolve())
        self._watched = [w for w in self._watched if w.path != path]
        self._path_to_module.pop(path, None)

    @property
    def watched_files(self) -> list[str]:
        """Return list of currently watched file paths."""
        return [w.path for w in self._watched]

    # -- Polling ---------------------------------------------------------------

    def poll(self, dt: float | None = None) -> list[str]:
        """Check watched files for changes and reload the modules of changed files.

        Files are only checked once at least ``poll_interval`` seconds have
        accumulated across calls. Pass the frame delta time as *dt*. When *dt*
        is omitted, the wall-clock time since the previous enabled call is
        used, so a bare ``mgr.poll()`` once per frame also works.

        Returns an empty list immediately when :attr:`enabled` is False, so
        the per-frame call site stays cheap when hot reload is toggled off.

        Args:
            dt: Seconds elapsed since the previous call, or None to measure it.

        Returns:
            List of module names that were reloaded.
        """
        if not self.enabled:
            return []
        now = time.monotonic()
        if dt is None:
            last = self._last_poll_time
            dt = 0.0 if last is None else now - last
        self._last_poll_time = now
        self._time_accumulator += dt
        if self._time_accumulator < self.poll_interval:
            return []
        self._time_accumulator = 0.0
        reloaded: list[str] = []
        # Iterate a snapshot: signal handlers may call watch()/unwatch().
        for wf in list(self._watched):
            if not wf.changed():
                continue
            try:
                self._reload_module(wf.module_name)
                reloaded.append(wf.module_name)
            except Exception as e:
                error_msg = f"{type(e).__name__}: {e}"
                log.error("hot_reload: failed to reload %s: %s", wf.module_name, error_msg)
                self.reload_failed.emit(wf.module_name, error_msg)
        return reloaded

    # -- Public reload API -----------------------------------------------------

    def force_reload(self, module_name: str) -> bool:
        """Force-reload a specific module by name.

        Failures are logged and reported through ``reload_failed``.

        Returns:
            True if reload succeeded, False otherwise.
        """
        try:
            self._reload_module(module_name)
            return True
        except Exception as e:
            error_msg = f"{type(e).__name__}: {e}"
            log.error("hot_reload: failed to reload %s: %s", module_name, error_msg)
            self.reload_failed.emit(module_name, error_msg)
            return False

    def reload_module(self, file_path: str) -> bool:
        """Reload a module by its file path.

        Reimports the module, finds the Node subclasses it defines, and updates
        live nodes. On failure, live nodes keep their classes and
        ``reload_failed`` is emitted.

        Args:
            file_path: Path to the .py file to reload.

        Returns:
            True if reload succeeded, False otherwise.
        """
        path = str(Path(file_path).resolve())
        module_name = self._path_to_module.get(path) or self._find_module_name(path)
        if module_name is None:
            module_name = Path(path).stem
        return self.force_reload(module_name)

    # -- Node re-instantiation -------------------------------------------------

    def reinstance_node(self, old_node: Node, new_class: type[Node]) -> Node:
        """Create a new instance of *new_class*, copying state and children from *old_node*.

        *new_class* is instantiated with no arguments. If *old_node* has a
        parent, the new node replaces it there (same parent, same position
        among siblings). The children of *old_node* are moved to the new node
        whether or not *old_node* has a parent.
        Clears any ``_script_error`` flag so processing resumes.
        Emits ``node_reinstanced(old_node, new_node)``.

        NOTE(review): this method never touches ``self.tree``. If *old_node* is
        the tree root, the caller must install the returned node as root.

        Returns:
            The newly created node.
        """
        # The snapshot's "children" entry is unused; children are moved directly.
        state = _serialize_node_state(old_node)
        parent = old_node.parent
        # Create fresh instance
        new_node = new_class()
        _restore_node_state(new_node, state)
        new_node._script_error = False
        # Position among siblings, captured before the tree is modified.
        idx = None
        if parent is not None:
            idx = next((i for i, child in enumerate(parent.children) if child is old_node), None)
        children = list(old_node.children)
        old_node.clear_children()
        if parent is not None:
            parent.remove_child(old_node)
            parent.add_child(new_node)
            # add_child appended new_node last; move it back to the old slot.
            if idx is not None and idx < len(parent.children) - 1:
                parent.children._list.remove(new_node)
                parent.children._list.insert(idx, new_node)
                parent.children._dirty = True
        # Re-attach original children to new node (also when there is no parent)
        for child in children:
            new_node.add_child(child)
        self.node_reinstanced.emit(old_node, new_node)
        log.info("hot_reload: reinstanced %s as %s", old_node.name, new_class.__name__)
        return new_node

    # -- Internal reload logic -------------------------------------------------

    @staticmethod
    def _node_classes(module: types.ModuleType, module_name: str) -> dict[str, type]:
        """Return ``{attribute_name: class}`` for Node subclasses whose ``__module__`` is *module_name*."""
        return {
            name: obj
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj.__module__ == module_name
        }

    def _reload_module(self, module_name: str):
        """Reload a module and switch live node instances to the new class versions.

        Raises whatever ``importlib.import_module`` / ``importlib.reload``
        raise; callers turn that into ``reload_failed``. A node whose class
        cannot be swapped (``TypeError``, e.g. an incompatible ``__slots__``
        layout) is logged and left on its old class.
        """
        module = sys.modules.get(module_name)
        if module is None:
            module = importlib.import_module(module_name)
        old_classes = self._node_classes(module, module_name)
        # Find live nodes using old classes (including errored ones) and
        # snapshot them before the module is touched.
        live_nodes = self._find_live_nodes(old_classes)
        snapshots = [(node, _serialize_node_state(node), node._script_error) for node in live_nodes]
        # Invalidate bytecode cache so importlib.reload reads fresh source
        self._invalidate_bytecode_cache(module)
        # If this raises, no node has been modified yet. NOTE: when the module
        # fails while executing (not at compile time), names defined before
        # the failing statement may already be rebound in its namespace.
        module = importlib.reload(module)
        new_classes = self._node_classes(module, module_name)
        replaced_classes: list[str] = []
        for node, state, was_errored in snapshots:
            class_name = state["class"]
            new_cls = new_classes.get(class_name)
            if new_cls is None:
                continue
            # Swap the class on the existing instance (avoids re-parenting)
            try:
                node.__class__ = new_cls
            except TypeError as e:
                log.warning("hot_reload: cannot switch %s to reloaded class %s: %s", node.name, class_name, e)
                continue
            _restore_node_state(node, state)
            # Clear error flag so the node resumes processing after a fix
            if was_errored:
                node._script_error = False
                log.info("hot_reload: cleared error on %s (class %s reloaded)", node.name, class_name)
            if class_name not in replaced_classes:
                replaced_classes.append(class_name)
        if replaced_classes:
            log.info("hot_reload: reloaded %s -> %s", module_name, replaced_classes)
            self.module_reloaded.emit(module_name, replaced_classes)

    def _find_live_nodes(self, classes: dict[str, type]) -> list[Node]:
        """Find all nodes in the tree whose exact class is one of the given classes."""
        if not self.tree.root:
            return []
        class_set = set(classes.values())
        return [node for node in self.tree.root.walk() if type(node) in class_set]

    @staticmethod
    def _find_module_name(path: str) -> str | None:
        """Return the name of a loaded module whose ``__file__`` resolves to *path*, or None."""
        target = Path(path).resolve()
        # Iterate a snapshot: reading attributes of lazily loaded modules can
        # trigger imports that mutate sys.modules mid-iteration.
        for name, mod in list(sys.modules.items()):
            if mod is None:
                continue
            mod_file = getattr(mod, "__file__", None)
            if not mod_file or not isinstance(mod_file, (str, os.PathLike)):
                continue
            if Path(mod_file).resolve() == target:
                return name
        return None

    @staticmethod
    def _invalidate_bytecode_cache(module: types.ModuleType) -> None:
        """Delete the cached .pyc for *module* so importlib.reload recompiles from source.

        Default .pyc files are validated against the source's recorded mtime
        and size. A quick re-save that keeps both could otherwise reuse stale
        bytecode. This is best-effort: a missing cache file and filesystem
        errors are ignored.
        """
        src = getattr(module, "__file__", None)
        if not src:
            return
        try:
            os.unlink(importlib.util.cache_from_source(src))
        except (NotImplementedError, OSError):
            # NotImplementedError: sys.implementation.cache_tag is None.
            # OSError covers FileNotFoundError when no cache file exists.
            pass