"""ScriptManager — file-based class script loading for nodes.
Scripts are standard Python classes that extend the node's type. ScriptManager
imports the module, looks up the class by name, and swaps ``node.__class__``
so the node gains the script's methods (ready, process, etc.).
Script references use the ``path::ClassName`` format (e.g. ``"player.py::Player"``).
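Always give the class name. For backward compatibility, a legacy reference without ``::ClassName`` is still accepted when the module defines exactly one Node subclass; the reference is then rewritten to include that class name. Embedded scripts (source stored on the node instead of in a file) are also supported; their class is located by inspecting the source.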
Public API:
    ScriptManager.load(node, project_dir)
    ScriptManager.unload(node)
    ScriptManager.reload(node, project_dir)
    ScriptManager.load_tree(root, project_dir)
    ScriptManager.invalidate(abs_path)
    ScriptManager.clear_cache()
    parse_script_ref(script) -> (path, class_name)
"""
import ast
import importlib
import logging
import sys
from pathlib import Path
from types import ModuleType

from .node import Node
# Module logger; ScriptManager reports load/import failures here instead of raising.
log = logging.getLogger(__name__)
# Names exported by ``from <this module> import *``.
__all__ = ["ScriptManager", "parse_script_ref"]
[docs]
def parse_script_ref(script: str) -> tuple[str, str | None]:
"""Parse a script reference into (file_path, class_name).
``"player.py::Player"`` → ``("player.py", "Player")``
``"player.py"`` (legacy, no ``::``)) → ``("player.py", None)``
"""
if "::" in script:
path, class_name = script.rsplit("::", 1)
return path, class_name
return script, None
def _find_node_subclass_in_source(source: str) -> str | None:
"""Parse Python source with ast to find the single Node subclass defined.
Returns the class name if exactly one Node subclass is found, None otherwise.
"""
try:
tree = ast.parse(source)
except SyntaxError:
return None
candidates = []
for node in ast.walk(tree):
if isinstance(node, ast.ClassDef) and node.bases:
candidates.append(node.name)
if len(candidates) == 1:
return candidates[0]
return None
class ScriptManager:
    """Static manager for loading/unloading file-based and embedded class scripts on nodes.

    All state lives in class-level dictionaries shared by every caller. Every
    method is a classmethod/staticmethod; the class is not meant to be
    instantiated.
    """

    # Resolved absolute script path -> module object executed from that file.
    _module_cache: dict[str, ModuleType] = {}
    # Source key (resolved file path, or module name for embedded scripts)
    # -> attribute names of the Node subclasses that source defined.
    _script_classes: dict[str, set[str]] = {}  # source key -> {class_name, ...}
    # Source key -> name of the module created for it. Unregistering only removes
    # Node._registry entries that still belong to *that* module, never a
    # same-named class registered later by a different script.
    _script_module_names: dict[str, str] = {}
    # Module-name prefixes given to modules created by this manager.
    _SCRIPT_MODULE_PREFIXES = ("_simvx_script_", "_simvx_embed_")

    @classmethod
    def load(cls, node: Node, project_dir: str = "") -> bool:
        """Load a file-based or embedded script onto *node*.

        For file-backed scripts:
        - Parse ``node.script`` for ``path::ClassName`` and resolve the path
          relative to *project_dir*.
        - Import the module, look up the class by name, and swap ``node.__class__``.
        - If the file does not exist but ``ClassName`` is already in
          ``Node._registry``, use the registered class instead.
        - A legacy reference without ``::ClassName`` is accepted when the module
          defines exactly one Node subclass; ``node.script`` is then rewritten
          to include the class name.

        For embedded scripts (``node._script_embedded`` set and ``node.script``
        empty): register the source with EmbeddedScriptFinder, import it, find
        the class via AST, and apply the same class swap.

        Loading again without an intervening :meth:`unload` keeps the node's
        original (pre-script) class, so :meth:`unload` can still restore it.

        Returns True on success, False on error (logged, not raised).
        """
        # Embedded script takes priority if no file path
        embedded = getattr(node, "_script_embedded", None)
        if embedded and not node.script:
            return cls._load_embedded(node, embedded)
        if not node.script:
            return False

        script_path, class_name = parse_script_ref(node.script)
        path = cls._resolve_path(script_path, project_dir)

        if path is None or not path.is_file():
            if class_name and class_name in Node._registry:
                # No file — use registry (class already known from a previous import or Python definition)
                return cls._apply_class(node, Node._registry[class_name], "registry")
            log.error("ScriptManager: file not found: %s", script_path)
            return False

        # File exists — always import from it (supports reload)
        module = cls._import_module(str(path))
        if module is None:
            return False

        upgraded_ref = None
        if class_name:
            # Look up the class by explicit name
            target_cls = getattr(module, class_name, None)
            if target_cls is None:
                log.error("ScriptManager: class %r not found in %s", class_name, path.name)
                return False
        else:
            # Legacy: no class name in script ref — find classes defined in this module
            defined = cls._find_defined_classes(module)
            if not defined:
                log.error("ScriptManager: no Node subclass defined in %s", path.name)
                return False
            if len(defined) > 1:
                log.error(
                    "ScriptManager: multiple classes in %s (%s) — use 'path::ClassName' format",
                    path.name,
                    ", ".join(c.__name__ for c in defined),
                )
                return False
            target_cls = defined[0]
            upgraded_ref = f"{script_path}::{target_cls.__name__}"

        if not cls._apply_class(node, target_cls, path.name, module=module):
            return False
        if upgraded_ref is not None:
            # Auto-upgrade the script ref to include the class name (only once the swap succeeded)
            node.script = upgraded_ref
        return True

    @classmethod
    def _apply_class(
        cls,
        node: Node,
        target_cls: object,
        source: str,
        module: ModuleType | None = None,
    ) -> bool:
        """Swap ``node.__class__`` to *target_cls*; return False (logged) on failure.

        *source* only names where the class came from, for log messages.
        When *module* is given it is stored on ``node._script_module``;
        otherwise that attribute is left untouched (as for registry lookups).
        """
        if not (isinstance(target_cls, type) and issubclass(target_cls, Node)):
            log.error(
                "ScriptManager: %r from %s is not a Node subclass",
                getattr(target_cls, "__name__", target_cls),
                source,
            )
            return False
        # Remember the class the node had before *any* script was applied, so a
        # repeated load (without unload) does not replace it with a script class.
        original = getattr(node, "_script_original_class", None) or type(node)
        try:
            node.__class__ = target_cls
        except TypeError as e:
            # Python rejects __class__ assignment when the instance layouts are incompatible.
            log.error(
                "ScriptManager: cannot apply %s to %s: %s",
                target_cls.__name__,
                type(node).__name__,
                e,
            )
            return False
        node._script_original_class = original
        if module is not None:
            node._script_module = module
        return True

    @classmethod
    def _find_defined_classes(cls, module: ModuleType) -> list[type]:
        """Return Node subclasses *defined* in this module (not imported)."""
        result = []
        for obj in vars(module).values():
            if not isinstance(obj, type) or not issubclass(obj, Node):
                continue
            if obj is Node:
                continue
            # Only classes whose __module__ matches this module
            if obj.__module__ == module.__name__:
                result.append(obj)
        return result

    @staticmethod
    def _defined_class_names(module: ModuleType, mod_name: str) -> set[str]:
        """Return attribute names in *module* bound to Node subclasses whose ``__module__`` is *mod_name*."""
        return {
            name
            for name, obj in vars(module).items()
            if isinstance(obj, type) and issubclass(obj, Node) and obj is not Node and obj.__module__ == mod_name
        }

    @classmethod
    def unload(cls, node: Node) -> None:
        """Restore the node's original class, removing the script behavior."""
        original = getattr(node, "_script_original_class", None)
        if original is not None:
            node.__class__ = original
        node._script_original_class = None
        node._script_module = None

    @classmethod
    def reload(cls, node: Node, project_dir: str = "") -> bool:
        """Reload a node's script (re-import module, re-swap class).

        File-backed scripts have their cached module invalidated first, so the
        file is executed again. Embedded scripts are re-imported from
        ``node._script_embedded``.

        Property values are preserved across the reload; a value the reloaded
        class no longer accepts is skipped.

        Returns the result of :meth:`load`, or False if the node has no script.
        """
        if not node.script and not getattr(node, "_script_embedded", None):
            return False
        # Serialize property state
        state = {}
        for name, prop in node.get_properties().items():
            try:
                state[name] = getattr(node, prop.attr, prop.default)
            except (AttributeError, TypeError):
                pass
        if node.script:
            # Invalidate cache so load() re-executes the file
            script_path, _ = parse_script_ref(node.script)
            path = cls._resolve_path(script_path, project_dir)
            if path is not None:
                cls.invalidate(str(path))
        cls.unload(node)
        ok = cls.load(node, project_dir)
        if ok:
            # Restore properties that still exist on the reloaded class
            current_props = node.get_properties()
            for name, val in state.items():
                if name not in current_props:
                    continue
                try:
                    setattr(node, name, val)
                except (TypeError, AttributeError, ValueError):
                    # Best effort: the new script may have changed the property's type/validation.
                    log.debug("ScriptManager: could not restore property %r on %s", name, node.name)
        return ok

    @classmethod
    def load_tree(cls, root: Node, project_dir: str = "") -> list[Node]:
        """Walk *root*'s tree and load scripts on every node that has one.

        Covers both file-based and embedded scripts. Children are loaded before
        their parents.

        Returns the list of nodes whose class was swapped so the caller can invoke
        ``ready()`` or other lifecycle methods.
        """
        loaded: list[Node] = []
        cls._walk_load(root, project_dir, loaded)
        return loaded

    @classmethod
    def invalidate(cls, abs_path: str) -> None:
        """Remove a module from the cache and its registry entries."""
        abs_path = str(Path(abs_path).resolve())
        cls._module_cache.pop(abs_path, None)
        cls._unregister_script_classes(abs_path)

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the entire module cache and all script-originated registry entries."""
        cls._module_cache.clear()
        for key in list(cls._script_classes):
            cls._unregister_script_classes(key)
        cls._script_module_names.clear()

    @classmethod
    def _unregister_script_classes(cls, source_key: str) -> None:
        """Remove Node._registry entries that originated from *source_key*.

        An entry is removed only if it still points at a class from the module
        created for *source_key*. A same-named class registered since then by
        another script is kept.
        """
        expected_module = cls._script_module_names.pop(source_key, None)
        for class_name in cls._script_classes.pop(source_key, ()):
            reg_cls = Node._registry.get(class_name)
            if reg_cls is None:
                continue
            mod = getattr(reg_cls, "__module__", "") or ""
            if expected_module is not None:
                owned = mod == expected_module
            else:
                # No recorded module name: fall back to "any script-created module".
                owned = mod.startswith(cls._SCRIPT_MODULE_PREFIXES)
            if owned:
                del Node._registry[class_name]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @classmethod
    def _resolve_path(cls, script_path: str, project_dir: str) -> Path | None:
        """Resolve a script path (absolute or relative to project_dir)."""
        p = Path(script_path)
        if p.is_absolute():
            return p
        if project_dir:
            return Path(project_dir) / p
        return p

    @classmethod
    def _import_module(cls, abs_path: str) -> ModuleType | None:
        """Import a .py file, caching the result.

        Compiles from source directly to bypass Python's bytecode cache (.pyc),
        which can return stale code when a file is modified within the same second.
        The module is not inserted into ``sys.modules``.

        Returns None (after logging the traceback) if reading, compiling or
        executing the file fails.
        """
        abs_path = str(Path(abs_path).resolve())
        cached = cls._module_cache.get(abs_path)
        if cached is not None:
            return cached
        mod_name = f"_simvx_script_{Path(abs_path).stem}_{id(abs_path)}"
        try:
            # Python source is UTF-8 by default; don't depend on the locale, and
            # strip a leading BOM, which compile() would reject in a str.
            source = Path(abs_path).read_text(encoding="utf-8-sig")
            code = compile(source, abs_path, "exec")
            module = ModuleType(mod_name)
            module.__file__ = abs_path
            # Lets tooling (e.g. linecache/tracebacks) fetch the source of this synthetic module.
            module.__loader__ = type("_Loader", (), {"get_source": staticmethod(lambda _name: source)})()  # type: ignore[assignment]
            exec(code, module.__dict__)  # noqa: S102
            class_names = cls._defined_class_names(module, mod_name)
        except Exception as e:
            log.exception("ScriptManager: import failed for %s: %s", abs_path, e)
            return None
        cls._module_cache[abs_path] = module
        # Track which Node subclasses this script defined
        cls._script_classes[abs_path] = class_names
        cls._script_module_names[abs_path] = mod_name
        return module

    @classmethod
    def _load_embedded(cls, node: Node, source: str) -> bool:
        """Load an embedded script by registering it as a virtual module.

        The module name is derived from ``id(node)``. Any module previously
        imported under that name is dropped from ``sys.modules`` first, because
        ``importlib.import_module`` would otherwise return it unchanged. Without
        that, edited source would be ignored, and a new node that reuses a freed
        node's id would get the old node's class.

        Returns True on success, False on error (logged, not raised).
        """
        from .script_embed import EmbeddedScriptFinder

        # Find the class name via AST
        class_name = _find_node_subclass_in_source(source)
        # Deterministic module name from node identity
        mod_name = f"_simvx_embed_{id(node):x}"

        def discard() -> None:
            # Undo everything registered for mod_name after a failed load.
            EmbeddedScriptFinder.unregister(mod_name)
            sys.modules.pop(mod_name, None)
            cls._unregister_script_classes(mod_name)

        # Drop leftovers from an earlier load under the same name.
        sys.modules.pop(mod_name, None)
        cls._unregister_script_classes(mod_name)
        EmbeddedScriptFinder.register(mod_name, source)
        try:
            module = importlib.import_module(mod_name)
        except Exception as e:
            log.exception("ScriptManager: embedded script failed for '%s': %s", node.name, e)
            discard()
            return False

        # Track which classes this embedded script defined
        cls._script_classes[mod_name] = cls._defined_class_names(module, mod_name)
        cls._script_module_names[mod_name] = mod_name

        if class_name:
            target_cls = getattr(module, class_name, None)
            if target_cls is None:
                log.error("ScriptManager: class %r not found in embedded script for %s", class_name, node.name)
                discard()
                return False
        else:
            # Fallback: find defined classes
            defined = cls._find_defined_classes(module)
            if not defined:
                log.error("ScriptManager: no Node subclass in embedded script for %s", type(node).__name__)
                discard()
                return False
            if len(defined) > 1:
                log.error(
                    "ScriptManager: multiple classes in embedded script (%s) — cannot determine which to use",
                    ", ".join(c.__name__ for c in defined),
                )
                discard()
                return False
            target_cls = defined[0]

        if not cls._apply_class(node, target_cls, f"embedded script for {node.name}", module=module):
            discard()
            return False
        return True

    @classmethod
    def _walk_load(cls, node: Node, project_dir: str, loaded: list[Node]) -> None:
        """Load file-based and embedded scripts on *node*'s subtree, children first.

        The walk is iterative, so very deep trees cannot hit the recursion limit.
        Each node's children are snapshotted when the node is first visited,
        and nodes are loaded in the same post-order a recursive walk would use.
        """
        stack: list[tuple[Node, bool]] = [(node, False)]
        while stack:
            current, children_done = stack.pop()
            if not children_done:
                stack.append((current, True))
                # Pushed reversed so the first child is popped (and loaded) first.
                stack.extend((child, False) for child in reversed(list(current.children)))
                continue
            if current.script or getattr(current, "_script_embedded", None):
                if cls.load(current, project_dir):
                    loaded.append(current)