# Source code for simvx.core.cst_codegen

"""CST-aware code generation -- read/write Python scene files preserving structure.

Provides utilities for detecting scene scripts, generating updated Python source
from live node trees, and checking for ambiguous procedural patterns.

Public API:
    is_scene_script(path)       -- check if a Python file defines a scene
    codegen_scene_file(path)    -- generate updated source from a scene file
    has_ambiguities(source, class_name) -- detect procedural add_child patterns
    save_tree_to_source(root)   -- generate Python source from a live node tree
"""

import ast
import builtins
import importlib
import importlib.util
import keyword
import logging
import sys
from pathlib import Path

import numpy as np

from .math.types import Quat, Vec2, Vec3
from .node import Node

# Module-level logger, named after this module's import path.
log = logging.getLogger(__name__)

# Public API (mirrors the module docstring); every other name is a private helper.
__all__ = ["is_scene_script", "codegen_scene_file", "has_ambiguities", "save_tree_to_source"]

# Engine Node class names. The AST-only checks treat a class whose written base
# is one of these as a scene class (only direct bases are compared by name).
_NODE_BASE_NAMES = {
    # Generic / non-spatial
    "Node", "Timer", "AudioStreamPlayer", "ParticleEmitter",
    # *2D names
    "Node2D", "Camera2D", "Text2D", "Line2D", "Polygon2D", "Sprite2D",
    "AnimatedSprite2D", "CharacterBody2D", "CollisionShape2D", "Area2D",
    "RigidBody2D", "StaticBody2D", "AudioStreamPlayer2D",
    # Canvas layering
    "CanvasLayer", "ParallaxBackground", "ParallaxLayer", "CanvasModulate",
    "YSortContainer",
    # *3D names
    "Node3D", "Camera3D", "OrbitCamera3D", "MeshInstance3D", "Light3D",
    "DirectionalLight3D", "PointLight3D", "SpotLight3D", "Text3D",
    "CharacterBody3D", "CollisionShape3D", "Area3D", "RigidBody3D",
    "StaticBody3D", "AudioStreamPlayer3D",
    # UI controls
    "Control", "Panel", "Label", "Button", "TextEdit",
}

# ============================================================================
# is_scene_script -- lightweight AST detection
# ============================================================================

[docs] def is_scene_script(path: str | Path) -> bool: """Check if a Python file defines a scene (contains Node subclass with __init__). Uses stdlib ``ast`` for lightweight parsing -- no imports needed. Returns False for non-existent files or parse errors. """ path = Path(path) if not path.is_file(): return False try: source = path.read_text() tree = ast.parse(source) except (OSError, SyntaxError): return False for node in ast.walk(tree): if not isinstance(node, ast.ClassDef): continue if not _has_node_base(node): continue # Check for __init__ with add_child or Property-style patterns for item in node.body: if isinstance(item, ast.FunctionDef) and item.name == "__init__": return True # Also accept classes without __init__ if they inherit from a Node type # (they use the parent's __init__ with Property descriptors) if _has_property_descriptors(node): return True return False
def _has_node_base(cls_def: ast.ClassDef) -> bool: """Check if a ClassDef inherits from a known Node type.""" for base in cls_def.bases: name = _ast_name(base) if name and name in _NODE_BASE_NAMES: return True return False def _has_property_descriptors(cls_def: ast.ClassDef) -> bool: """Check if a ClassDef has Property(...) assignments at class level.""" for item in cls_def.body: if isinstance(item, ast.Assign): for target in item.targets: if isinstance(target, ast.Name) and isinstance(item.value, ast.Call): func = _ast_name(item.value.func) if func == "Property": return True return False def _ast_name(node: ast.expr) -> str | None: """Extract a simple name from an AST node (Name or Attribute).""" if isinstance(node, ast.Name): return node.id if isinstance(node, ast.Attribute): return node.attr return None # ============================================================================ # codegen_scene_file -- import + instantiate + regenerate # ============================================================================
def codegen_scene_file(path: str | Path) -> str | None:
    """Generate updated Python source from a scene file.

    Imports the module, finds the primary Node class, instantiates it, reads
    the live tree, and generates updated source via save_tree_to_source().
    This executes the file's module-level code and the class constructor.

    Args:
        path: Path of the scene script.

    Returns:
        The new source string (does NOT write to disk), or None on failure:
        missing file, no importable spec, no Node subclass, or any exception
        while importing, instantiating or generating (which is logged).
    """
    path = Path(path)
    if not path.is_file():
        log.error("codegen_scene_file: file not found: %s", path)
        return None

    # Private module name so a real module of the same name is never replaced.
    module_name = f"_cst_codegen_tmp_{path.stem}"
    absent = object()  # distinguishes "no entry" from an entry holding None
    previous = absent
    registered = False
    try:
        spec = importlib.util.spec_from_file_location(module_name, str(path))
        if spec is None or spec.loader is None:
            log.error("codegen_scene_file: cannot create module spec for %s", path)
            return None
        module = importlib.util.module_from_spec(spec)

        # Keep the module registered for as long as its classes are in use,
        # not only while it executes: code that resolves a class through
        # sys.modules[cls.__module__] (e.g. typing.get_type_hints, pickle)
        # would otherwise fail inside node_cls() or during generation.
        previous = sys.modules.get(module_name, absent)
        sys.modules[module_name] = module
        registered = True
        spec.loader.exec_module(module)

        # Find the primary Node subclass (first one defined in the module)
        node_cls = _find_primary_node_class(module, path)
        if node_cls is None:
            log.error("codegen_scene_file: no Node subclass found in %s", path)
            return None

        # Instantiate the scene tree
        root = node_cls()

        # Collect imports from the original source. decode_source honours
        # coding cookies the same way the import above did (the locale's
        # default encoding may differ) and normalises newlines.
        source = importlib.util.decode_source(path.read_bytes())
        imports = _extract_imports(source)
        return save_tree_to_source(root, class_name=node_cls.__name__, imports=imports)
    except Exception:
        log.exception("codegen_scene_file: failed to process %s", path)
        return None
    finally:
        if registered:
            if previous is absent:
                sys.modules.pop(module_name, None)
            else:
                sys.modules[module_name] = previous
def _find_primary_node_class(module, path: Path) -> type | None: """Find the primary Node subclass defined in a module. Returns the first class that (a) is a Node subclass and (b) is defined in the given file (not imported from elsewhere). """ # Parse AST to get class names defined in the file try: source = path.read_text() tree = ast.parse(source) except (OSError, SyntaxError): return None defined_names = { node.name for node in ast.walk(tree) if isinstance(node, ast.ClassDef) } for name in defined_names: obj = getattr(module, name, None) if obj is not None and isinstance(obj, type) and issubclass(obj, Node) and obj is not Node: return obj return None def _extract_imports(source: str) -> list[str]: """Extract import lines from Python source.""" try: tree = ast.parse(source) except SyntaxError: return [] lines = source.splitlines() import_lines: list[str] = [] for node in ast.iter_child_nodes(tree): if isinstance(node, ast.Import | ast.ImportFrom): # Get the original source lines for this import start = node.lineno - 1 end = node.end_lineno if node.end_lineno else node.lineno import_lines.append("\n".join(lines[start:end])) return import_lines # ============================================================================ # has_ambiguities -- detect procedural patterns # ============================================================================
[docs] def has_ambiguities(source: str, class_name: str | None = None) -> bool: """Check if __init__ has procedural code (loops, conditionals with add_child). Uses stdlib ``ast`` to detect patterns like ``for ... add_child()`` or ``if ... add_child()`` that cannot be round-tripped losslessly. """ try: tree = ast.parse(source) except SyntaxError: return True # If we can't parse it, assume ambiguous for node in ast.walk(tree): if not isinstance(node, ast.ClassDef): continue if class_name and node.name != class_name: continue if not _has_node_base(node): continue for item in node.body: if isinstance(item, ast.FunctionDef) and item.name == "__init__": if _init_has_procedural_add_child(item): return True return False
def _init_has_procedural_add_child(func_def: ast.FunctionDef) -> bool: """Check if an __init__ method has add_child calls inside loops or conditionals.""" for node in ast.walk(func_def): if not isinstance(node, ast.For | ast.While | ast.If): continue # Check if any add_child call exists inside this control flow for inner in ast.walk(node): if isinstance(inner, ast.Call): func = inner.func if isinstance(func, ast.Attribute) and func.attr == "add_child": return True return False # ============================================================================ # save_tree_to_source -- live node tree to Python source # ============================================================================
def save_tree_to_source(
    root: Node,
    class_name: str | None = None,
    imports: list[str] | None = None,
) -> str:
    """Generate Python source from a live node tree.

    Produces a Python class with an ``__init__`` that reconstructs the tree via
    ``add_child()`` calls, writing only non-default kwargs.

    Args:
        root: The root node of the tree to serialize.
        class_name: Name of the generated class. When omitted, the root node's
            ``name`` is used if it is a usable identifier: not a keyword, not
            the root's type name, and not a type the generated body refers to.
            Otherwise the root's type name is used, or ``"Scene"`` if that
            type name is itself referenced by the body.
        imports: Optional list of import lines to include at the top. When
            omitted or empty, one ``from simvx.core import ...`` line covering
            every referenced type (and the base class) is generated.

    Returns:
        Complete Python source string.
    """
    # Use the pre-script type -- if a user script has replaced the root's
    # class, the live type isn't importable from simvx.core.
    root_type_cls = getattr(root, "_script_original_class", None) or type(root)
    root_type_name = root_type_cls.__name__

    # Emit the body first: the type names it references decide which class
    # names are safe to choose below.
    ctx = _SourceContext()
    body_lines = ctx.emit_node(root, "self", is_root=True)
    # Non-default properties on the root are passed to super().__init__().
    root_kwargs = ctx.build_root_kwargs(root)

    if class_name is None:
        candidate = root.name
        if (
            candidate
            and candidate.isidentifier()
            and not keyword.iskeyword(candidate)
            and candidate != root_type_name
            and candidate not in ctx.used_types
        ):
            # Prefer the root node's name (e.g. "Arena") over the engine type
            # name ("Node3D") so the class doesn't shadow the imported type.
            class_name = candidate
        elif root_type_name in ctx.used_types:
            # Naming the class after a type the body constructs would make
            # __init__ instantiate the generated class itself (infinite recursion).
            class_name = "Scene"
        else:
            class_name = root_type_name

    # Drop ``name=...`` when it matches the generated class name; the Node
    # default uses ``type(self).__name__`` which is already correct, and
    # emitting it would double-specify when callers pass ``name=`` themselves
    # (e.g. during scene deserialization).
    root_kwargs = [(k, v) for k, v in root_kwargs if not (k == "name" and v == repr(class_name))]
    super_args = ", ".join(["**kwargs", *(f"{k}={v}" for k, v in root_kwargs)])

    base_type = root_type_name
    if base_type == class_name:
        # The class IS the root type (e.g. a user class "Player(Node3D)"):
        # derive from that type's immediate base. Use the pre-script class
        # here as well -- a script subclass's MRO would yield ``class X(X)``.
        mro = root_type_cls.__mro__
        base_type = mro[1].__name__ if len(mro) > 1 else "Node"

    parts: list[str] = []
    if imports:
        parts.extend(imports)
    else:
        # Minimal import line covering every referenced type plus the base class.
        ctx.used_types.add(base_type)
        parts.append(f"from simvx.core import {', '.join(sorted(ctx.used_types))}")
    parts += ["", ""]

    parts.append(f"class {class_name}({base_type}):")
    parts.append("    def __init__(self, **kwargs):")
    parts.append(f"        super().__init__({super_args})")
    if body_lines:
        parts.append("")
        parts.extend(f"        {line}" for line in body_lines)
    parts.append("")
    return "\n".join(parts)
class _SourceContext:
    """Tracks types used and generates code for a live node tree.

    ``used_types`` accumulates every type name the emitted code references
    (for the import line). Variable names handed out by ``_unique_var`` are
    unique across the whole tree.
    """

    def __init__(self):
        # Type names referenced by emitted code (node types and value types).
        self.used_types: set[str] = set()
        self._var_counter: int = 0
        # Every variable name handed out so far -> highest numeric suffix
        # tried for it as a base name.
        self._seen_names: dict[str, int] = {}

    def _unique_var(self, name: str) -> str:
        """Generate a unique, valid Python variable name from a node name."""
        base = name.lower().replace(" ", "_").replace("-", "_")
        # Keep alphanumerics/underscores that are also legal identifier
        # characters: str.isalnum() accepts e.g. "²", which Python names reject.
        base = "".join(c for c in base if (c.isalnum() or c == "_") and ("_" + c).isidentifier())
        if not base.isidentifier():
            # e.g. empty, or starting with a digit
            base = f"node_{base}"
        if base == "self" or keyword.iskeyword(base) or hasattr(builtins, base):
            # "self" would rebind the instance, keywords are syntax errors, and
            # builtin names would be shadowed for the rest of __init__.
            base = f"node_{base}"
        if base not in self._seen_names:
            self._seen_names[base] = 0
            return base
        # Try suffixes until one is free. Generated names are registered too,
        # so a later node literally named e.g. "foo_1" cannot reuse them.
        suffix = self._seen_names[base]
        while True:
            suffix += 1
            candidate = f"{base}_{suffix}"
            if candidate not in self._seen_names:
                break
        self._seen_names[base] = suffix
        self._seen_names[candidate] = 0
        return candidate

    def build_root_kwargs(self, node: Node) -> list[tuple[str, str]]:
        """Build kwargs for the root node (passed to super().__init__)."""
        return self._build_kwargs(node)

    def emit_node(self, node: Node, var_name: str, is_root: bool = False) -> list[str]:
        """Emit Python lines for a node and its children.

        Args:
            node: Node to serialize.
            var_name: Variable the node is bound to (``"self"`` for the root).
            is_root: When True, no constructor line is emitted; the root is
                built by the generated class's ``super().__init__()``.

        Returns:
            Unindented source lines for this subtree.
        """
        lines: list[str] = []
        # Use the pre-script type when a script has replaced the class. The
        # script class lives in a user script file, not simvx.core, and isn't
        # importable from there.
        real_type = getattr(node, "_script_original_class", None) or type(node)
        node_type_name = real_type.__name__

        if not is_root:
            # The root's own type is handled by the class definition itself.
            self.used_types.add(node_type_name)
            kwargs_str = ", ".join(f"{k}={v}" for k, v in self._build_kwargs(node))
            lines.append(f"{var_name} = {node_type_name}({kwargs_str})")

        # Script references -- preserved so editor-attached scripts survive save
        if getattr(node, "script", None):
            lines.append(f"{var_name}.script = {node.script!r}")
        if getattr(node, "_script_inline", None):
            lines.append(f"{var_name}._script_inline = {node._script_inline!r}")
        if getattr(node, "_script_embedded", None):
            literal = self._triple_quoted(node._script_embedded)
            lines.append(f"{var_name}._script_embedded = {literal}")

        for child in node.children:
            child_var = self._unique_var(child.name)
            lines.extend(self.emit_node(child, child_var))
            lines.append(f"{var_name}.add_child({child_var})")
        return lines

    @staticmethod
    def _triple_quoted(text: str) -> str:
        """Return a triple-double-quoted literal that evaluates back to *text*.

        Newlines and most quotes stay literal for readability. Backslashes,
        carriage returns, NULs, every third consecutive double quote, and a
        double quote directly before the closing delimiter are escaped.
        """
        escapes = {"\\": "\\\\", "\r": "\\r", "\0": "\\x00"}
        out: list[str] = []
        quote_run = 0  # consecutive unescaped double quotes emitted so far
        for ch in text:
            if ch == '"':
                quote_run += 1
                if quote_run == 3:
                    out.append('\\"')
                    quote_run = 0
                else:
                    out.append('"')
            else:
                out.append(escapes.get(ch, ch))
                quote_run = 0
        if quote_run:
            # An unescaped trailing quote would merge with the closing delimiter.
            out[-1] = '\\"'
        return '"""' + "".join(out) + '"""'

    def _build_kwargs(self, node: Node) -> list[tuple[str, str]]:
        """Build a list of (key, value_repr) pairs for non-default constructor kwargs."""
        kwargs: list[tuple[str, str]] = []
        # Emit ``name`` when it differs from the constructed type's default.
        # The node may have been class-swapped by a script, so use the
        # pre-script (structural) type -- the class the generated source builds.
        real_type = getattr(node, "_script_original_class", None) or type(node)
        if node.name != real_type.__name__:
            kwargs.append(("name", repr(node.name)))

        # Spatial properties for Node2D/Node3D
        kwargs.extend(self._spatial_kwargs(node))

        # Property descriptors
        for prop_name, prop in node.get_properties().items():
            val = getattr(node, prop_name)
            if _is_default(val, prop.default):
                continue
            formatted = _format_value(val)
            if formatted is None:
                continue  # Non-serializable value -- skip
            self._register_value_type(val)
            kwargs.append((prop_name, formatted))
        return kwargs

    def _spatial_kwargs(self, node: Node) -> list[tuple[str, str]]:
        """Build kwargs for spatial properties (position, rotation, scale)."""
        # Imported here rather than at module level (presumably to avoid an
        # import cycle -- TODO confirm).
        from .nodes_2d.node2d import Node2D
        from .nodes_3d.node3d import Node3D

        kwargs: list[tuple[str, str]] = []
        if isinstance(node, Node3D):
            pos = node.position
            if not _is_zero_vec3(pos):
                self.used_types.add("Vec3")
                kwargs.append(("position", _format_value(Vec3(pos))))
            rot = node.rotation
            if not _is_identity_quat(rot):
                self.used_types.add("Quat")
                kwargs.append(("rotation", _format_value(rot)))
            scl = node.scale
            if not _is_one_vec3(scl):
                self.used_types.add("Vec3")
                kwargs.append(("scale", _format_value(Vec3(scl))))
        elif isinstance(node, Node2D):
            pos = node.position
            if not _is_zero_vec2(pos):
                self.used_types.add("Vec2")
                kwargs.append(("position", _format_value(Vec2(pos))))
            rot = node.rotation
            if abs(rot) > 1e-9:
                kwargs.append(("rotation", _fmt_num(rot)))
            scl = node.scale
            if not _is_one_vec2(scl):
                self.used_types.add("Vec2")
                kwargs.append(("scale", _format_value(Vec2(scl))))
        return kwargs

    def _register_value_type(self, val) -> None:
        """Track types needed for import generation.

        Recurses into tuples, lists and dict values, because _format_value also
        emits ``Vec3(...)`` etc. for vectors nested in containers. Types are
        checked in the same order as _format_value, so the registered name
        matches the emitted constructor.
        """
        if isinstance(val, Quat):
            self.used_types.add("Quat")
        elif isinstance(val, Vec3):
            self.used_types.add("Vec3")
        elif isinstance(val, Vec2):
            self.used_types.add("Vec2")
        elif isinstance(val, tuple | list):
            for item in val:
                self._register_value_type(item)
        elif isinstance(val, dict):
            for item in val.values():
                self._register_value_type(item)


# ============================================================================
# Value formatting
============================================================================ def _is_serializable(value) -> bool: """Check if a value can be serialized to Python source. Returns False for Node instances, callables, modules, and other non-representable types. """ if value is None: return True if isinstance(value, bool | int | float | str): return True if isinstance(value, Vec2 | Vec3 | Quat): return True if isinstance(value, np.ndarray): return True if isinstance(value, tuple | list): return all(_is_serializable(v) for v in value) if isinstance(value, dict): return all(isinstance(k, str) and _is_serializable(v) for k, v in value.items()) # Non-serializable: Node instances, callables, modules, etc. if isinstance(value, Node): return False if callable(value): return False return False def _format_value(val) -> str | None: """Format a value for Python source output. Returns None for non-serializable values (Node instances, callables, etc.). """ if not _is_serializable(val): return None if val is None: return "None" if isinstance(val, bool): return repr(val) if isinstance(val, Quat): return f"Quat({_fmt_num(val.w)}, {_fmt_num(val.x)}, {_fmt_num(val.y)}, {_fmt_num(val.z)})" if isinstance(val, Vec3): return f"Vec3({_fmt_num(val[0])}, {_fmt_num(val[1])}, {_fmt_num(val[2])})" if isinstance(val, Vec2): return f"Vec2({_fmt_num(val[0])}, {_fmt_num(val[1])})" if isinstance(val, np.ndarray): items = ", ".join(_fmt_num(float(v)) for v in val.flat) return f"[{items}]" if isinstance(val, int | float): return _fmt_num(val) if isinstance(val, str): return repr(val) if isinstance(val, tuple): items = ", ".join(_format_value(v) or "None" for v in val) return f"({items},)" if len(val) == 1 else f"({items})" if isinstance(val, list): items = ", ".join(_format_value(v) or "None" for v in val) return f"[{items}]" if isinstance(val, dict): items = ", ".join(f"{k!r}: {_format_value(v)}" for k, v in val.items()) return "{" + items + "}" return None def _fmt_num(v) -> str: """Format a number 
cleanly.""" if isinstance(v, int): return str(v) if isinstance(v, float | np.floating): fv = float(v) if fv == int(fv) and abs(fv) < 1e15: return f"{int(fv)}.0" return f"{fv:g}" return str(v) # ============================================================================ # Comparison helpers # ============================================================================ def _is_default(val, default) -> bool: """Check if a value equals its default.""" try: eq = val == default if isinstance(eq, np.ndarray): return bool(eq.all()) return bool(eq) except Exception: return False def _is_zero_vec2(v) -> bool: return abs(float(v[0])) < 1e-9 and abs(float(v[1])) < 1e-9 def _is_zero_vec3(v) -> bool: return abs(float(v[0])) < 1e-9 and abs(float(v[1])) < 1e-9 and abs(float(v[2])) < 1e-9 def _is_one_vec2(v) -> bool: return abs(float(v[0]) - 1.0) < 1e-9 and abs(float(v[1]) - 1.0) < 1e-9 def _is_one_vec3(v) -> bool: return abs(float(v[0]) - 1.0) < 1e-9 and abs(float(v[1]) - 1.0) < 1e-9 and abs(float(v[2]) - 1.0) < 1e-9 def _is_identity_quat(q) -> bool: return abs(float(q.w) - 1.0) < 1e-9 and abs(float(q.x)) < 1e-9 and abs(float(q.y)) < 1e-9 and abs(float(q.z)) < 1e-9