"""High-level scene-shaped editing surface for parso-parsed sources.
This is Tier 3b of the scene I/O layer. It composes the lossless parse
(:mod:`source_tree`) and prefix-preserving primitives (:mod:`edits`) with
greenfield emission (:mod:`emitter`) and structural detection
(:mod:`detection`) into a small public API:
SceneFile — a parsed Python file with byte-perfect round-trip save.
SceneClass — an editable view of one Node-subclass in the file.
ImportSet — an editable view of the file's top-level imports.
The editor's save/load path and the IDE's refactor tools build on this
surface. Lower tiers remain available for callers that need finer control.
"""
from __future__ import annotations
from pathlib import Path
from typing import TYPE_CHECKING
from parso.python.tree import Class
from parso.tree import Leaf, NodeOrLeaf
from . import edits as _edits
from .detection import primary_node_class_from_source
from .emitter import emit_scene
from .source_tree import SourceTree, parse_snippet, parse_source
if TYPE_CHECKING:
from ..node import Node
__all__ = ["ImportSet", "SceneClass", "SceneFile"]
# ---------------------------------------------------------------------------
# SceneFile
# ---------------------------------------------------------------------------
class SceneFile:
    """A parsed Python scene file with byte-perfect round-trip save.

    Holds a parso tree plus the original on-disk text. All edits operate on
    the parso tree; :meth:`save` writes ``tree.get_code()``. Round-trip
    identity is guaranteed when no edits were made: files are read and
    written as UTF-8 (the default encoding of Python source) and the
    line-ending convention detected by :meth:`load` is restored by
    :meth:`save`, so platform newline translation and CRLF files do not
    alter the bytes.
    """

    __slots__ = ("_source_tree", "_path", "_imports", "_original_text", "_newline")

    def __init__(self, source_tree: SourceTree, *, path: Path | None) -> None:
        self._source_tree = source_tree
        self._path = path
        self._original_text = source_tree.original_text
        self._imports = ImportSet(source_tree)
        # ``newline=`` argument handed to ``open`` by :meth:`save`. "\n" means
        # "no translation": the dumped text is written exactly as-is on every
        # platform. :meth:`load` replaces it with the convention found on disk.
        self._newline = "\n"

    # -- constructors --------------------------------------------------------

    @classmethod
    def load(cls, path: str | Path) -> SceneFile:
        """Read ``path`` and parse it.

        The file is decoded as UTF-8 and its line-ending convention is
        remembered so :meth:`save` can write it back unchanged.

        Raises :class:`FileNotFoundError` when the file is absent;
        :class:`UnicodeDecodeError` when it is not valid UTF-8;
        :class:`parso.ParserSyntaxError` for malformed Python (we use
        ``error_recovery=False`` for explicit save targets so syntax issues
        surface immediately rather than silently producing a partial tree).
        """
        p = Path(path)
        text, newline = cls._read_source(p)
        source_tree = parse_source(text, error_recovery=False)
        scene = cls(source_tree, path=p)
        scene._newline = newline
        return scene

    @staticmethod
    def _read_source(path: Path) -> tuple[str, str]:
        """Read ``path`` as UTF-8 and return ``(text, newline)``.

        For a file with one consistent line-ending convention (LF, CRLF or
        CR), ``text`` is decoded with universal newlines (LF only) and
        ``newline`` is that convention, suitable for ``open(newline=...)``.
        A file with no line breaks reports LF. A file with *mixed* endings
        cannot be restored after translation, so it is re-read untranslated
        and paired with ``newline=""`` (write back verbatim).
        """
        with path.open(encoding="utf-8") as fh:
            text = fh.read()
            seen = fh.newlines  # None, a single str, or a tuple when mixed
        if seen is None:
            return text, "\n"
        if isinstance(seen, str):
            return text, seen
        with path.open(encoding="utf-8", newline="") as fh:
            return fh.read(), ""

    @classmethod
    def from_source(cls, text: str, *, path: Path | None = None) -> SceneFile:
        """Parse already-loaded source ``text``.

        ``path`` is recorded for :meth:`save` and error messages but is not
        read. ``text`` is written back verbatim by :meth:`save` (no newline
        translation).
        """
        source_tree = parse_source(text, error_recovery=False)
        return cls(source_tree, path=path)

    @classmethod
    def from_runtime(cls, root: Node, *, class_name: str | None = None) -> SceneFile:
        """Greenfield: emit source for a live :class:`Node` tree, then parse it.

        The returned :class:`SceneFile` has no ``path`` until :meth:`save`
        is called with one.
        """
        text = emit_scene(root, class_name=class_name)
        return cls.from_source(text, path=None)

    # -- properties ----------------------------------------------------------

    @property
    def path(self) -> Path | None:
        """Path the file was loaded from / will be saved to, or ``None``."""
        return self._path

    @property
    def source_tree(self) -> SourceTree:
        """Underlying lossless :class:`SourceTree`."""
        return self._source_tree

    @property
    def imports(self) -> ImportSet:
        """Editable view of the file's top-level imports."""
        return self._imports

    # -- scene class lookup --------------------------------------------------

    def scene_class(self) -> SceneClass:
        """The single primary Node subclass in the file.

        Raises :class:`AmbiguousSceneError` (from detection) if the file
        contains multiple Node subclasses; raises :class:`ValueError` if it
        contains none, or if the detected class is not defined at module
        top level.
        """
        text = self._source_tree.dump()
        # `primary_node_class_from_source` raises AmbiguousSceneError on
        # multiple matches and returns ``None`` on zero matches.
        name = primary_node_class_from_source(text, path=self._path)
        if name is None:
            where = f" in {self._path}" if self._path is not None else ""
            raise ValueError(f"no Node subclass found{where}")
        cls_node = self._source_tree.find_class(name)
        if cls_node is None:
            # Detection found a class via ast walk (which descends into
            # nested scopes); the parso lookup is top-level only. Surface
            # the discrepancy clearly rather than silently failing.
            raise ValueError(
                f"primary class {name!r} is not at module top-level; "
                "scene classes must be defined at the top of the file"
            )
        return SceneClass(self, cls_node)

    def all_scene_classes(self) -> list[SceneClass]:
        """Every top-level Node subclass defined in the file, in source order.

        Used for diagnostics and IDE features. The typical scene has one.
        Detection mirrors :func:`primary_node_class_from_source`'s rule
        (Node base + ``__init__`` or class-body ``Property`` descriptors).
        Returns an empty list when the current source does not parse.
        """
        import ast

        from .detection import _is_scene_class

        try:
            ast_tree = ast.parse(self._source_tree.dump())
        except SyntaxError:
            return []
        names = [
            node.name
            for node in ast.iter_child_nodes(ast_tree)
            if isinstance(node, ast.ClassDef) and _is_scene_class(node)
        ]
        out: list[SceneClass] = []
        for name in names:
            cls_node = self._source_tree.find_class(name)
            if cls_node is not None:
                out.append(SceneClass(self, cls_node))
        return out

    # -- top-level class insertion ------------------------------------------

    def insert_top_level_class(
        self,
        name: str,
        base: str,
        *,
        body: str = "pass",
        before: str | None = None,
    ) -> SceneClass:
        """Insert ``class <name>(<base>): <body>`` at module scope.

        ``body`` may span several lines. It is dedented and re-indented one
        level under the class header, and an empty body becomes ``pass``.

        ``base`` is auto-imported from ``simvx.core`` via :class:`ImportSet`
        only when it is a plain identifier that is neither already imported
        (from any module) nor defined as a top-level class in this file.

        Placement is just before the first top-level class (normally the
        scene class), or before ``before`` when given, so the new definition
        sits between imports and the scene that uses it. When the file has
        no classes, the new one is appended at the end. Returns the new
        :class:`SceneClass` view.

        Raises :class:`ValueError` if a top-level class with the same name
        already exists or if ``before`` names a missing class. These checks
        run before any edit is made.
        """
        import textwrap

        if self._source_tree.find_class(name) is not None:
            raise ValueError(f"top-level class {name!r} already exists")

        # Resolve (and validate) the anchor before touching the tree so a
        # bad ``before`` does not leave a stray import behind.
        anchor_name = before
        if anchor_name is None:
            existing = list(self._source_tree.iter_classes())
            if existing:
                anchor_name = existing[0].name.value
        anchor_class = None
        if anchor_name is not None:
            anchor_class = self._source_tree.find_class(anchor_name)
            if anchor_class is None:
                raise ValueError(f"anchor class {anchor_name!r} not found")

        body_src = textwrap.dedent(body).strip("\n") or "pass"
        indented_body = textwrap.indent(body_src, "    ")
        new_class = parse_snippet(f"class {name}({base}):\n{indented_body}\n")
        new_class.parent = None

        # Dotted / subscripted bases, names imported from another module and
        # classes defined in this file must not get a `simvx.core` import.
        if (
            base.isidentifier()
            and not self._imports.has_any_alias(base)
            and self._source_tree.find_class(base) is None
        ):
            self._imports.ensure(base, from_="simvx.core")

        module = self._source_tree.module
        children = module.children
        if anchor_class is not None:
            # Find the wrapping top-level child (classdef or decorated).
            target = anchor_class
            while target.parent is not None and target.parent is not module:
                target = target.parent
            # Index is taken after the import edit above, which may have
            # shifted module children.
            idx = children.index(target)
            anchor_prefix = target.get_first_leaf().prefix
            # The inserted class inherits the anchor's prefix (blank lines /
            # comments that sat above it). The anchor gets "\n\n": after the
            # snippet's trailing newline that yields two blank lines between
            # the two classes.
            new_class.get_first_leaf().prefix = anchor_prefix
            target.get_first_leaf().prefix = "\n\n"
            new_class.parent = module
            children.insert(idx, new_class)
        else:
            # No existing classes: append at end (before endmarker).
            endmarker_idx = next(
                (i for i, c in enumerate(children) if c.type == "endmarker"),
                len(children),
            )
            # With prior content (normally newline-terminated), a "\n\n"
            # prefix gives two blank lines above the new class.
            new_class.get_first_leaf().prefix = "\n\n" if endmarker_idx > 0 else ""
            new_class.parent = module
            children.insert(endmarker_idx, new_class)

        # Resolve the freshly inserted SceneClass.
        cls_node = self._source_tree.find_class(name)
        if cls_node is None:
            raise RuntimeError(f"insert_top_level_class: failed to locate {name!r} after insertion")
        return SceneClass(self, cls_node)

    # -- serialisation -------------------------------------------------------

    def dump(self) -> str:
        """Current source as a string."""
        return self._source_tree.dump()

    def is_dirty(self) -> bool:
        """True iff :meth:`dump` differs from the original input text.

        Used by :class:`SceneModule` to skip writes for files that were
        opened but never edited.
        """
        return self.dump() != self._original_text

    def save(self, path: str | Path | None = None) -> Path:
        """Write to ``path`` (or to ``self.path`` if not given).

        The text is encoded as UTF-8 and written with the line-ending
        convention recorded at load time, which makes an unedited
        load/save cycle byte-identical. Returns the path written.

        Raises :class:`ValueError` if neither ``path`` nor ``self.path`` is
        set. ``self._path`` is updated on success so subsequent saves
        without an argument reuse the last destination.
        """
        target = Path(path) if path is not None else self._path
        if target is None:
            raise ValueError("save() requires a path: SceneFile has no recorded path")
        text = self.dump()
        with target.open("w", encoding="utf-8", newline=self._newline) as fh:
            fh.write(text)
        self._path = target
        self._original_text = text
        return target

    def assert_idempotent(self) -> None:
        """Assert that :meth:`dump` equals the original input text.

        Useful for tests that verify no accidental edits leaked into a
        load/save round-trip. Raises :class:`AssertionError` with a unified
        diff otherwise.
        """
        current = self.dump()
        if current != self._original_text:
            import difflib

            diff = "".join(
                difflib.unified_diff(
                    self._original_text.splitlines(keepends=True),
                    current.splitlines(keepends=True),
                    fromfile="original",
                    tofile="current",
                )
            )
            raise AssertionError(f"SceneFile is not idempotent:\n{diff}")
# ---------------------------------------------------------------------------
# SceneClass
# ---------------------------------------------------------------------------
class SceneClass:
    """An editable view of one Node-subclass class definition inside a scene.

    The view holds the owning :class:`SceneFile` (used for import
    bookkeeping) and the parso ``Class`` node. Every edit mutates that node
    in place, so changes show up in :meth:`SceneFile.dump`.
    """

    __slots__ = ("_file", "_class")

    def __init__(self, file: SceneFile, class_node: Class) -> None:
        self._file = file
        self._class = class_node

    # -- identity ------------------------------------------------------------

    @property
    def name(self) -> str:
        """Class name (e.g. ``"Arena"``)."""
        return self._class.name.value

    @property
    def node(self) -> Class:
        """Underlying parso :class:`Class` node."""
        return self._class

    # -- class-level Property descriptors ------------------------------------

    def has_property(self, name: str) -> bool:
        """True iff this class body itself declares ``name = Property(...)``."""
        return self._find_class_property(name) is not None

    def get_property_default(self, name: str) -> str | None:
        """Source text of the Property's default expression, or ``None``.

        Returns ``None`` when ``name`` is not declared in this class body or
        when its ``Property(...)`` call has no positional default.
        Inherited Properties are not visible — use the runtime tree to
        observe inherited values.
        """
        node = self._find_class_property(name)
        if node is None:
            return None
        # node is the simple_stmt wrapping `name = Property(...)`.
        atom_expr = self._property_atom_expr(node)
        trailer = atom_expr.children[-1]
        first_arg = _first_positional_value(trailer)
        if first_arg is None:
            return None
        return first_arg.get_code().strip()

    def add_property(self, name: str, default_expr: str) -> None:
        """Insert ``name = Property(default_expr)`` into the class body.

        The new line goes after the last existing class-level Property
        declaration. When there is none, it goes after the class docstring,
        so the docstring stays the first statement. If there is no docstring
        either, it goes at the top of the body. Auto-imports ``Property``
        via the file's :class:`ImportSet`.

        Raises :class:`ValueError` if ``name`` is already declared here.
        """
        if self.has_property(name):
            raise ValueError(f"property {name!r} already declared on {self.name}")
        self._file.imports.ensure("Property", from_="simvx.core")
        suite = self._class_suite()
        anchor_stmt = self._last_class_property_stmt()
        if anchor_stmt is None:
            anchor_stmt = self._class_docstring_stmt()
        snippet = parse_snippet(f"{name} = Property({default_expr})\n")
        if anchor_stmt is not None:
            indent = _edits._indent_of(_get_prefix(anchor_stmt))  # type: ignore[attr-defined]
            _set_prefix(snippet, indent)
            snippet.parent = suite
            children = suite.children
            children.insert(children.index(anchor_stmt) + 1, snippet)
        else:
            # Nothing to follow: insert right after the class header's NEWLINE.
            _edits.insert_after(self._suite_anchor(suite), snippet)

    def remove_property(self, name: str) -> None:
        """Delete the ``name = Property(...)`` line from the class body.

        Raises :class:`ValueError` if ``name`` is not declared here.
        """
        node = self._find_class_property(name)
        if node is None:
            raise ValueError(f"property {name!r} not declared on {self.name}")
        _edits.remove_node(node)

    def set_property_default(self, name: str, default_expr: str) -> None:
        """Replace (or add) the positional default of ``name = Property(...)``.

        If the call has no positional default yet, ``default_expr`` becomes
        its first argument. A ``, `` is added when other (keyword) arguments
        are already present, so ``Property(doc="x")`` becomes
        ``Property(<default>, doc="x")``.

        Raises :class:`ValueError` if ``name`` is not declared here.
        """
        node = self._find_class_property(name)
        if node is None:
            raise ValueError(f"property {name!r} not declared on {self.name}")
        atom_expr = self._property_atom_expr(node)
        trailer = atom_expr.children[-1]
        old_value = _first_positional_value(trailer)
        new_value = parse_snippet(default_expr)
        if old_value is not None:
            _edits.replace_node(old_value, new_value, preserve_prefix=True)
            return
        # No positional default: trailer.children == ["(", <args>?, ")"].
        had_args = len(trailer.children) > 2
        new_value.parent = None
        _set_prefix(new_value, "")
        trailer.children.insert(1, new_value)
        new_value.parent = trailer
        if had_args:
            # Separate the new positional value from the existing arguments.
            comma = _edits._make_op(",", prefix="")  # type: ignore[attr-defined]
            comma.parent = trailer
            trailer.children.insert(2, comma)
            first_existing = trailer.children[3].get_first_leaf()
            if not first_existing.prefix:
                first_existing.prefix = " "

    # -- root super().__init__ kwargs ---------------------------------------

    def get_root_kwarg(self, name: str) -> str | None:
        """Source text of kwarg ``name`` in ``super().__init__(...)``, or ``None``.

        Also ``None`` when ``__init__`` (or its ``super().__init__()`` call)
        is missing.
        """
        trailer = self._super_init_trailer()
        if trailer is None:
            return None
        val = _edits.get_call_kwarg(trailer, name)
        return val.get_code().strip() if val is not None else None

    def set_root_kwarg(self, name: str, value_expr: str) -> None:
        """Update or insert a kwarg in the root ``super().__init__(...)`` call.

        Raises :class:`ValueError` when there is no such call to edit.
        """
        trailer = self._super_init_trailer()
        if trailer is None:
            raise ValueError(f"{self.name}.__init__ has no super().__init__() call to edit")
        _edits.set_call_kwarg(trailer, name, value_expr)

    def remove_root_kwarg(self, name: str) -> None:
        """Remove kwarg ``name`` from the root ``super().__init__(...)`` call.

        Raises :class:`ValueError` if the call or the kwarg is missing.
        """
        trailer = self._super_init_trailer()
        if trailer is None:
            raise ValueError(f"{self.name}.__init__ has no super().__init__() call")
        for arg in _arglist_arguments(trailer):
            if _argument_name(arg) == name:
                _remove_argument(arg)
                return
        raise ValueError(f"kwarg {name!r} not found on super().__init__()")

    # -- children: assignments + add_child calls ---------------------------

    def has_child(self, var_name: str) -> bool:
        """True iff ``__init__`` has a top-level assignment to ``var_name``."""
        return self._find_child_assignment(var_name) is not None

    def child_var_names(self) -> list[str]:
        """Variable names passed to top-level ``self.add_child(<var>)`` calls
        in ``__init__``, in source order (empty when there is no ``__init__``)."""
        out: list[str] = []
        for stmt in self._init_body_stmts():
            var = _add_child_var_name(stmt)
            if var is not None:
                out.append(var)
        return out

    def add_child(
        self,
        var_name: str,
        type_name: str,
        *,
        before: str | None = None,
        after: str | None = None,
        from_module: str = "simvx.core",
        **kwarg_exprs: str,
    ) -> None:
        """Insert a child construction + ``self.add_child`` pair into ``__init__``.

        Position: by default the pair goes after the last existing
        ``self.add_child(...)`` call. Without one, it goes after
        ``super().__init__()``, or at the top of the body if that is missing
        too. ``before=`` or ``after=`` (mutually exclusive) places the pair
        relative to another child instead.

        Auto-imports ``type_name`` from ``from_module`` (defaults to
        ``simvx.core``) only when the name is not already imported from any
        module and is not a top-level class defined in this file. Importing
        a same-file class would be a circular self-import. Callers placing
        user classes from other modules should pass
        ``from_module=type(node).__module__``.

        Raises :class:`ValueError` in any of these cases:

        - ``var_name`` already exists in the ``__init__`` body;
        - both ``before`` and ``after`` are passed;
        - the anchor child is missing;
        - the class has no ``__init__``.

        These checks run before any edit is made.

        Note: when the source is procedural (children built inside loops or
        conditionals — see :func:`has_procedural_construction`), the inserted
        statements are placed at the top level of ``__init__`` and may
        execute in a surprising order relative to the procedural code.
        """
        if before is not None and after is not None:
            raise ValueError("add_child: pass at most one of `before` or `after`")
        if self.has_child(var_name):
            raise ValueError(f"child variable {var_name!r} already exists in {self.name}.__init__")
        suite = self._init_suite()  # raises ValueError when there is no __init__
        if before is not None:
            anchor = self._find_child_assignment(before)
            if anchor is None:
                raise ValueError(f"child {before!r} not found")
        elif after is not None:
            anchor = self._find_add_child_call(after)
            if anchor is None:
                raise ValueError(f"child {after!r} not found")
        else:
            anchor = self._last_add_child_stmt()
            if anchor is None:
                anchor = self._super_init_stmt()
            if anchor is None:
                anchor = self._suite_anchor(suite)

        kwargs_str = ", ".join(f"{k}={v}" for k, v in kwarg_exprs.items())
        assignment = parse_snippet(f"{var_name} = {type_name}({kwargs_str})\n")
        add_child = parse_snippet(f"self.add_child({var_name})\n")

        imports = self._file.imports
        if not imports.has_any_alias(type_name) and self._file.source_tree.find_class(type_name) is None:
            imports.ensure(type_name, from_=from_module)

        if before is not None:
            self._insert_pair_before(anchor, assignment, add_child)
        else:
            self._insert_pair_after(anchor, assignment, add_child)

    def remove_child(self, var_name: str) -> None:
        """Remove the assignment line and the ``self.add_child`` line.

        Does not auto-clean unused imports — that is the caller's
        responsibility (use :meth:`ImportSet.remove`). Raises
        :class:`ValueError` if neither statement exists for ``var_name``.
        """
        assignment = self._find_child_assignment(var_name)
        add_child = self._find_add_child_call(var_name)
        if assignment is None and add_child is None:
            raise ValueError(f"child {var_name!r} not found in {self.name}.__init__")
        if assignment is not None:
            _edits.remove_node(assignment)
        if add_child is not None:
            _edits.remove_node(add_child)

    def rename_child(self, old: str, new: str) -> None:
        """Rename a child variable in its assignment and ``self.add_child`` call.

        Only those two statements are rewritten. Any other references to
        ``old`` elsewhere in ``__init__`` are left untouched.

        Raises :class:`ValueError` if ``new`` is already assigned, or if
        either the assignment or the ``add_child`` call for ``old`` is
        missing.
        """
        if self.has_child(new):
            raise ValueError(f"cannot rename {old!r}: target name {new!r} already exists")
        assignment = self._find_child_assignment(old)
        add_child = self._find_add_child_call(old)
        if assignment is None or add_child is None:
            raise ValueError(f"child {old!r} not found in {self.name}.__init__")
        # The first child of an expr_stmt assignment is the target name leaf.
        target_leaf = assignment.children[0].children[0]
        target_leaf.value = new
        # Only look inside the call parentheses, so a child variable that
        # happens to be called `self` or `add_child` does not rename the
        # `self.add_child` method access instead of the argument.
        inside_call = False
        for leaf in _iter_leaves(add_child):
            if not inside_call:
                inside_call = leaf.type == "operator" and leaf.value == "("
                continue
            if leaf.type == "name" and leaf.value == old:
                leaf.value = new
                break

    def get_child_kwarg(self, var_name: str, kwarg: str) -> str | None:
        """Source text of ``kwarg`` in the child's constructor call, or ``None``."""
        ctor_trailer = self._child_ctor_trailer(var_name)
        if ctor_trailer is None:
            return None
        val = _edits.get_call_kwarg(ctor_trailer, kwarg)
        return val.get_code().strip() if val is not None else None

    def set_child_kwarg(self, var_name: str, kwarg: str, value_expr: str) -> None:
        """Update or insert ``kwarg`` in the child's constructor call.

        Raises :class:`ValueError` if the child (or its constructor call)
        is not found.
        """
        ctor_trailer = self._child_ctor_trailer(var_name)
        if ctor_trailer is None:
            raise ValueError(f"child {var_name!r} not found")
        _edits.set_call_kwarg(ctor_trailer, kwarg, value_expr)

    def remove_child_kwarg(self, var_name: str, kwarg: str) -> None:
        """Remove ``kwarg`` from the child's constructor call.

        Raises :class:`ValueError` if the child or the kwarg is missing.
        """
        ctor_trailer = self._child_ctor_trailer(var_name)
        if ctor_trailer is None:
            raise ValueError(f"child {var_name!r} not found")
        for arg in _arglist_arguments(ctor_trailer):
            if _argument_name(arg) == kwarg:
                _remove_argument(arg)
                return
        raise ValueError(f"kwarg {kwarg!r} not found on child {var_name!r}")

    def reorder_children(self, order: list[str]) -> None:
        """Reorder child assignment + ``add_child`` line pairs to match ``order``.

        Every existing child must appear in ``order`` exactly once, otherwise
        :class:`ValueError` is raised. The pairs are re-inserted as one
        contiguous block right after the statement that preceded the
        earliest of them.

        Each statement keeps its own prefix, so comments travel with the
        child they annotate. Only the leading blank lines of the block's
        original head move to the new head, which preserves the separation
        from the preceding code. Non-child statements that were interleaved
        with the pairs end up after the reordered block.
        """
        existing = self.child_var_names()
        if sorted(existing) != sorted(order) or len(existing) != len(order):
            raise ValueError(f"reorder_children: order {order!r} does not match existing children {existing!r}")
        if existing == order:
            return
        pairs: dict[str, tuple[NodeOrLeaf, NodeOrLeaf]] = {}
        for var in existing:
            assignment = self._find_child_assignment(var)
            add_child = self._find_add_child_call(var)
            if assignment is None or add_child is None:
                raise ValueError(f"child {var!r} structurally inconsistent during reorder")
            pairs[var] = (assignment, add_child)

        suite = self._init_suite()
        children = suite.children
        moved = [stmt for pair in pairs.values() for stmt in pair]
        # `existing` follows add_child-call order, so the first child's
        # assignment is not necessarily the earliest statement. Anchor on
        # whatever precedes the earliest pair statement: it is never moved.
        # Index 0 is the suite's NEWLINE, so first_idx - 1 is always valid.
        first_idx = min(children.index(stmt) for stmt in moved)
        head = children[first_idx]
        anchor = children[first_idx - 1]

        new_head = pairs[order[0]][0]
        if new_head is not head:
            head_blank, head_rest = self._split_leading_blank_lines(_get_prefix(head))
            new_blank, new_rest = self._split_leading_blank_lines(_get_prefix(new_head))
            _set_prefix(new_head, head_blank + new_rest)
            _set_prefix(head, new_blank + head_rest)

        for stmt in moved:
            children.remove(stmt)
            stmt.parent = None
        pos = children.index(anchor) + 1
        for var in order:
            for stmt in pairs[var]:
                stmt.parent = suite
                children.insert(pos, stmt)
                pos += 1

    # -- internal helpers ----------------------------------------------------

    @staticmethod
    def _split_leading_blank_lines(prefix: str) -> tuple[str, str]:
        """Split ``prefix`` into (leading whitespace-only lines, remainder)."""
        end = 0
        while True:
            newline_at = prefix.find("\n", end)
            if newline_at == -1 or prefix[end:newline_at].strip():
                return prefix[:end], prefix[end:]
            end = newline_at + 1

    def _class_suite(self) -> NodeOrLeaf:
        return self._class.children[-1]

    def _class_docstring_stmt(self) -> NodeOrLeaf | None:
        """The ``simple_stmt`` holding the class docstring, if the body starts with one."""
        for child in self._class_suite().children:
            if child.type == "newline":
                continue
            if child.type == "simple_stmt" and child.children and child.children[0].type in ("string", "strings"):
                return child
            return None
        return None

    def _init_funcdef(self) -> NodeOrLeaf | None:
        suite = self._class_suite()
        for child in suite.children:
            if child.type == "funcdef":
                name_leaf = child.children[1]
                if isinstance(name_leaf, Leaf) and name_leaf.value == "__init__":
                    return child
        return None

    def _init_suite(self) -> NodeOrLeaf:
        funcdef = self._init_funcdef()
        if funcdef is None:
            raise ValueError(f"{self.name} has no __init__")
        return funcdef.children[-1]

    def _suite_anchor(self, suite: NodeOrLeaf) -> NodeOrLeaf:
        # The first child of a suite is the introductory NEWLINE.
        return suite.children[0]

    def _init_body_stmts(self) -> list[NodeOrLeaf]:
        try:
            suite = self._init_suite()
        except ValueError:
            return []
        return [c for c in suite.children if c.type == "simple_stmt"]

    def _find_class_property(self, name: str) -> NodeOrLeaf | None:
        suite = self._class_suite()
        for stmt in suite.children:
            if stmt.type != "simple_stmt":
                continue
            atom_expr = self._property_atom_expr(stmt)
            if atom_expr is None:
                continue
            target_leaf = stmt.children[0].children[0]
            if isinstance(target_leaf, Leaf) and target_leaf.value == name:
                return stmt
        return None

    def _last_class_property_stmt(self) -> NodeOrLeaf | None:
        last = None
        suite = self._class_suite()
        for stmt in suite.children:
            if stmt.type != "simple_stmt":
                continue
            if self._property_atom_expr(stmt) is not None:
                last = stmt
        return last

    @staticmethod
    def _property_atom_expr(stmt: NodeOrLeaf) -> NodeOrLeaf | None:
        """If ``stmt`` is ``<name> = Property(...)``, return the call atom_expr;
        otherwise None."""
        if stmt.type != "simple_stmt" or not stmt.children:
            return None
        expr_stmt = stmt.children[0]
        if expr_stmt.type != "expr_stmt":
            return None
        if len(expr_stmt.children) < 3:
            return None
        target, eq, value = expr_stmt.children[0], expr_stmt.children[1], expr_stmt.children[2]
        if target.type != "name" or getattr(eq, "type", None) != "operator" or eq.value != "=":
            return None
        if value.type != "atom_expr" or not value.children:
            return None
        head = value.children[0]
        if head.type != "name" or head.value != "Property":
            return None
        return value

    def _super_init_stmt(self) -> NodeOrLeaf | None:
        for stmt in self._init_body_stmts():
            if _is_super_init_stmt(stmt):
                return stmt
        return None

    def _super_init_trailer(self) -> NodeOrLeaf | None:
        stmt = self._super_init_stmt()
        if stmt is None:
            return None
        atom_expr = stmt.children[0]
        return atom_expr.children[-1]

    def _find_child_assignment(self, var_name: str) -> NodeOrLeaf | None:
        for stmt in self._init_body_stmts():
            name = _assignment_target_name(stmt)
            if name == var_name and not _is_super_init_stmt(stmt):
                return stmt
        return None

    def _find_add_child_call(self, var_name: str) -> NodeOrLeaf | None:
        for stmt in self._init_body_stmts():
            if _add_child_var_name(stmt) == var_name:
                return stmt
        return None

    def _last_add_child_stmt(self) -> NodeOrLeaf | None:
        last = None
        for stmt in self._init_body_stmts():
            if _add_child_var_name(stmt) is not None:
                last = stmt
        return last

    def _insert_pair_after(self, anchor: NodeOrLeaf, assignment: NodeOrLeaf, add_child: NodeOrLeaf) -> None:
        """Insert ``assignment`` then ``add_child`` directly after ``anchor``.

        Both inserted statements sit at the body's indent with no extra
        blank line — children pack tightly like the corpus convention. When
        ``anchor`` is the suite's opening NEWLINE, its prefix is the text
        between ``:`` and the line break rather than the body indentation.
        In that case the indent is taken from the first body statement.
        """
        suite = self._init_suite()
        children = suite.children
        idx = children.index(anchor)
        indent_source = anchor
        if anchor.type == "newline" and idx + 1 < len(children):
            indent_source = children[idx + 1]
        indent = _edits._indent_of(_get_prefix(indent_source))  # type: ignore[attr-defined]
        _set_prefix(assignment, indent)
        _set_prefix(add_child, indent)
        assignment.parent = suite
        add_child.parent = suite
        children.insert(idx + 1, assignment)
        children.insert(idx + 2, add_child)

    def _insert_pair_before(self, anchor: NodeOrLeaf, assignment: NodeOrLeaf, add_child: NodeOrLeaf) -> None:
        """Insert ``assignment`` then ``add_child`` directly before ``anchor``.

        Inserted lines inherit ``anchor``'s indent; the original ``anchor``
        prefix (which may carry a leading blank line / comment) is kept on
        ``anchor`` itself.
        """
        suite = self._init_suite()
        children = suite.children
        idx = children.index(anchor)
        indent = _edits._indent_of(_get_prefix(anchor))  # type: ignore[attr-defined]
        _set_prefix(assignment, indent)
        _set_prefix(add_child, indent)
        assignment.parent = suite
        add_child.parent = suite
        children.insert(idx, assignment)
        children.insert(idx + 1, add_child)

    def _child_ctor_trailer(self, var_name: str) -> NodeOrLeaf | None:
        assignment = self._find_child_assignment(var_name)
        if assignment is None:
            return None
        expr_stmt = assignment.children[0]
        if expr_stmt.type != "expr_stmt" or len(expr_stmt.children) < 3:
            return None
        rhs = expr_stmt.children[2]
        if rhs.type != "atom_expr":
            return None
        last = rhs.children[-1]
        if last.type == "trailer" and last.children and last.children[0].value == "(":
            return last
        return None
# ---------------------------------------------------------------------------
# ImportSet
# ---------------------------------------------------------------------------
[docs]
class ImportSet:
"""Editable view of the file's top-level imports."""
__slots__ = ("_source_tree",)
def __init__(self, source_tree: SourceTree) -> None:
self._source_tree = source_tree
[docs]
def has(self, name: str, *, from_: str | None = None) -> bool:
for from_module, imported_name in self.names():
if imported_name == name and from_module == from_:
return True
return False
[docs]
def has_any_alias(self, name: str) -> bool:
"""True iff ``name`` is imported from anywhere (any module)."""
for _from_module, imported_name in self.names():
if imported_name == name:
return True
return False
[docs]
def ensure(self, name: str, *, from_: str | None = None) -> None:
"""Add ``import name`` or ``from <from_> import name`` if absent.
When ``from_`` matches an existing ``from <from_> import …`` line,
the new name is merged into that line (sorted, deduplicated)
instead of creating a separate import line.
"""
if self.has(name, from_=from_):
return
if from_ is not None:
# Try to merge into an existing "from <from_> import ..." line.
existing = self._find_import_from(from_)
if existing is not None:
self._merge_into_import_from(existing, name)
return
snippet_text = f"from {from_} import {name}\n" if from_ is not None else f"import {name}\n"
new_stmt = parse_snippet(snippet_text)
self._insert_import_line(new_stmt)
[docs]
def remove(self, name: str, *, from_: str | None = None) -> None:
"""Remove an import. If the line becomes empty, remove it.
No-op if ``name`` is not imported (with the given ``from_``).
"""
if from_ is None:
stmt = self._find_plain_import(name)
if stmt is not None:
_edits.remove_node(stmt)
return
import_from = self._find_import_from(from_)
if import_from is None:
return
names_node = self._import_from_names_node(import_from)
if names_node is None:
return
if names_node.type == "name":
if names_node.value == name:
# Sole imported name → remove the whole simple_stmt line.
stmt = self._stmt_for_import(import_from)
if stmt is not None:
_edits.remove_node(stmt)
return
# `import_as_names` — children are alternating `name` / `,`.
children = names_node.children
for idx, child in enumerate(children):
if child.type == "name" and child.value == name:
# Remove this name and one neighbouring comma so the list
# stays well-formed. If we're removing the first name, the
# second name needs to inherit the original first's prefix
# so the leading space after ``import`` is preserved.
if idx == 0 and len(children) >= 3 and children[1].type == "operator":
# Remove name + comma; let next name keep its own prefix
# but copy the leading space.
leading = child.prefix
del children[0:2]
if children and isinstance(children[0], Leaf):
children[0].prefix = leading
elif idx + 1 < len(children) and children[idx + 1].type == "operator":
del children[idx : idx + 2]
elif idx > 0 and children[idx - 1].type == "operator":
del children[idx - 1 : idx + 1]
else:
del children[idx]
# If only one name remains, collapse import_as_names into a
# bare name (matches parso's parsed shape for single-name
# imports).
remaining_names = [c for c in children if c.type == "name"]
if len(remaining_names) == 1 and len(children) == 1:
sole = children[0]
sole_idx = import_from.children.index(names_node)
sole.parent = import_from
import_from.children[sole_idx] = sole
if not remaining_names:
stmt = self._stmt_for_import(import_from)
if stmt is not None:
_edits.remove_node(stmt)
return
[docs]
def names(self) -> list[tuple[str | None, str]]:
"""List of ``(from_, name)`` pairs in source order."""
out: list[tuple[str | None, str]] = []
for imp in self._source_tree.iter_imports():
if imp.type == "import_name":
# `import_name` -> [keyword, target]
# target ∈ {name, dotted_name, dotted_as_name, dotted_as_names}
target = imp.children[1]
if target.type == "name":
out.append((None, target.value))
elif target.type == "dotted_as_name":
# [name, 'as', alias] — record the original name.
out.append((None, target.children[0].value))
elif target.type == "dotted_as_names":
for sub in target.children:
if sub.type == "name":
out.append((None, sub.value))
elif sub.type == "dotted_as_name":
out.append((None, sub.children[0].value))
elif target.type == "dotted_name":
out.append((None, target.children[0].value))
elif imp.type == "import_from":
from_module = self._import_from_module_name(imp)
names_node = self._import_from_names_node(imp)
if names_node is None:
continue
if names_node.type == "name":
out.append((from_module, names_node.value))
elif names_node.type == "import_as_names":
for sub in names_node.children:
if sub.type == "name":
out.append((from_module, sub.value))
elif sub.type == "import_as_name":
out.append((from_module, sub.children[0].value))
return out
# -- internal helpers ----------------------------------------------------
def _stmt_for_import(self, import_node: NodeOrLeaf) -> NodeOrLeaf | None:
"""The wrapping ``simple_stmt`` for an import_name/import_from."""
return import_node.parent
def _find_plain_import(self, name: str) -> NodeOrLeaf | None:
for imp in self._source_tree.iter_imports():
if imp.type != "import_name":
continue
target = imp.children[1]
if target.type == "name" and target.value == name:
return self._stmt_for_import(imp)
return None
def _find_import_from(self, module: str) -> NodeOrLeaf | None:
for imp in self._source_tree.iter_imports():
if imp.type != "import_from":
continue
if self._import_from_module_name(imp) == module:
return imp
return None
@staticmethod
def _import_from_module_name(import_from: NodeOrLeaf) -> str:
# children: [keyword 'from', <leading-dots>?, <module-node>?, keyword 'import', names_node]
# Relative imports prefix the module with one or more `.` operators,
# and `from . import x` has no module node at all.
children = import_from.children
parts: list[str] = []
for c in children[1:]:
if c.type == "keyword" and c.value == "import":
break
if c.type == "operator" and c.value == ".":
parts.append(".")
elif c.type == "name":
parts.append(c.value)
elif c.type == "dotted_name":
parts.append("".join(sub.value for sub in c.children if hasattr(sub, "value")))
return "".join(parts)
@staticmethod
def _import_from_names_node(import_from: NodeOrLeaf) -> NodeOrLeaf | None:
# Find the child after the 'import' keyword.
children = import_from.children
for i, c in enumerate(children):
if c.type == "keyword" and c.value == "import":
# Names node is next non-paren child.
for nxt in children[i + 1 :]:
if nxt.type == "operator" and nxt.value == "(":
continue
if nxt.type == "operator" and nxt.value == ")":
continue
return nxt
return None
def _merge_into_import_from(self, import_from: NodeOrLeaf, new_name: str) -> None:
    """Merge ``new_name`` into ``import_from``'s name list, sorted + dedup.

    Every existing entry is kept, including ``name as alias`` entries. The
    merged list is ordered by imported (left-hand) name. ``new_name`` counts
    as already present only when it appears as a bare, un-aliased entry; in
    that case nothing changes. The prefix between ``import`` and the first
    name, and a trailing comma (as in a parenthesised list), are preserved.
    Star imports (``from m import *``) cannot take extra names and are left
    untouched.
    """
    names_node = self._import_from_names_node(import_from)
    if names_node is None or names_node.type == "operator":
        # No names slot at all, or ``*``: there is nothing to merge into.
        return
    if names_node.type == "import_as_names":
        entry_nodes = [c for c in names_node.children if c.type != "operator"]
        last = names_node.children[-1]
        had_trailing_comma = last.type == "operator" and last.value == ","
    else:
        # A lone ``name`` or ``name as alias`` entry.
        entry_nodes = [names_node]
        had_trailing_comma = False

    # Map each entry's normalised source text to the name it imports. Using
    # a dict also collapses duplicate entries.
    entries: dict[str, str] = {}
    for node in entry_nodes:
        if node.type == "import_as_name":
            # children: [name, keyword 'as', name]
            imported = node.children[0].value
            entries[f"{imported} as {node.children[-1].value}"] = imported
        else:
            entries[node.value] = node.value
    if new_name in entries:
        return
    entries[new_name] = new_name
    ordered = sorted(entries, key=lambda text: (entries[text], text))

    new_snippet = parse_snippet(f"from x import {', '.join(ordered)}\n")
    fresh = self._import_from_names_node(new_snippet.children[0])
    # With at least two entries the snippet always yields import_as_names.
    if fresh is None or fresh.type != "import_as_names":
        return
    new_children = list(fresh.children)
    if had_trailing_comma:
        new_children.append(_edits._make_op(",", prefix=""))  # type: ignore[attr-defined]
    # Keep whatever separated ``import`` from the first name (usually " ").
    # Read it from the old node before anything is replaced.
    new_children[0].get_first_leaf().prefix = names_node.get_first_leaf().prefix

    if names_node.type == "import_as_names":
        # Reuse the existing node so other references to it remain valid.
        names_node.children[:] = new_children
        for child in new_children:
            child.parent = names_node
    else:
        fresh.parent = import_from
        old_idx = import_from.children.index(names_node)
        import_from.children[old_idx] = fresh
def _insert_import_line(self, new_stmt: NodeOrLeaf) -> None:
    """Insert a new import ``simple_stmt`` into the module.

    Placement rules:

    * If the module already has top-level imports, ``new_stmt`` goes right
      after the last of them with an empty prefix (packed tightly).
    * Otherwise, if the module opens with a docstring, ``new_stmt`` goes
      right after it, separated by one blank line. Placing it before the
      docstring would turn the docstring into a plain expression statement.
    * Otherwise ``new_stmt`` goes before the first non-newline child and
      takes over that child's prefix (e.g. header comments), leaving the
      child with an empty prefix.
    """
    module = self._source_tree.module
    top_children = module.children
    last_import_stmt: NodeOrLeaf | None = None
    for imp in self._source_tree.iter_imports():
        stmt = imp.parent
        if stmt is not None and stmt.parent is module:
            last_import_stmt = stmt
    if last_import_stmt is not None:
        # Top-level statements carry their own trailing newline, so an empty
        # prefix places the new statement on the very next line.
        _set_prefix(new_stmt, "")
        new_stmt.parent = module
        top_children.insert(top_children.index(last_import_stmt) + 1, new_stmt)
        return

    first_idx = 0
    while first_idx < len(top_children) and top_children[first_idx].type == "newline":
        first_idx += 1
    if first_idx >= len(top_children):
        return
    anchor = top_children[first_idx]
    new_stmt.parent = module
    is_docstring = (
        anchor.type == "simple_stmt"
        and bool(anchor.children)
        and anchor.children[0].type == "string"
    )
    if is_docstring:
        # The docstring must stay the first statement; the "\n" prefix adds
        # one blank line after it. The following statement keeps its prefix.
        _set_prefix(new_stmt, "\n")
        top_children.insert(first_idx + 1, new_stmt)
        return
    _set_prefix(new_stmt, _get_prefix(anchor))
    _set_prefix(anchor, "")
    top_children.insert(first_idx, new_stmt)
# ---------------------------------------------------------------------------
# Module-level helpers for SceneClass
# ---------------------------------------------------------------------------
def _is_super_init_stmt(stmt: NodeOrLeaf) -> bool:
"""True iff ``stmt`` is ``super().__init__(...)``."""
if stmt.type != "simple_stmt" or not stmt.children:
return False
inner = stmt.children[0]
if inner.type != "atom_expr" or not inner.children:
return False
head = inner.children[0]
if head.type != "name" or head.value != "super":
return False
# Expect: super, trailer(()), trailer(.__init__), trailer((...)).
if len(inner.children) < 4:
return False
return True
def _assignment_target_name(stmt: NodeOrLeaf) -> str | None:
"""For a ``simple_stmt`` whose inner is ``<name> = …`` return the name."""
if stmt.type != "simple_stmt" or not stmt.children:
return None
expr = stmt.children[0]
if expr.type != "expr_stmt" or len(expr.children) < 3:
return None
target, eq = expr.children[0], expr.children[1]
if target.type != "name":
return None
if getattr(eq, "type", None) != "operator" or eq.value != "=":
return None
return target.value
def _add_child_var_name(stmt: NodeOrLeaf) -> str | None:
"""If ``stmt`` is ``self.add_child(<var>)`` return ``<var>`` else None."""
if stmt.type != "simple_stmt" or not stmt.children:
return None
inner = stmt.children[0]
if inner.type != "atom_expr" or len(inner.children) < 3:
return None
head = inner.children[0]
if head.type != "name" or head.value != "self":
return None
dot_trailer = inner.children[1]
if dot_trailer.type != "trailer" or len(dot_trailer.children) < 2:
return None
if dot_trailer.children[0].value != ".":
return None
method_name = dot_trailer.children[1]
if getattr(method_name, "value", None) != "add_child":
return None
call_trailer = inner.children[2]
if call_trailer.type != "trailer" or len(call_trailer.children) < 3:
return None
if call_trailer.children[0].value != "(" or call_trailer.children[-1].value != ")":
return None
inner_args = call_trailer.children[1:-1]
if len(inner_args) != 1:
return None
arg = inner_args[0]
if arg.type != "name":
return None
return arg.value
def _arglist_arguments(trailer: NodeOrLeaf):
"""Yield ``argument`` nodes inside a call trailer.
Mirrors :func:`edits._iter_arguments` but kept local to scene_file so
SceneClass operations don't depend on private edits helpers.
"""
inner = trailer.children[1:-1]
if not inner:
return
if len(inner) == 1:
node = inner[0]
if node.type == "argument":
yield node
return
if node.type == "arglist":
for c in node.children:
if c.type == "argument":
yield c
return
return
for c in inner:
if c.type == "argument":
yield c
def _argument_name(arg: NodeOrLeaf) -> str | None:
if not hasattr(arg, "children") or len(arg.children) < 3:
return None
name_node, eq = arg.children[0], arg.children[1]
if name_node.type != "name" or getattr(eq, "type", None) != "operator" or eq.value != "=":
return None
return name_node.value
def _remove_argument(arg: NodeOrLeaf) -> None:
"""Remove an ``argument`` from its enclosing arglist or single-arg trailer."""
parent = arg.parent
if parent is None:
raise ValueError("_remove_argument: argument has no parent")
if parent.type == "arglist":
children = parent.children
idx = children.index(arg)
# Remove the argument and one neighbouring comma to keep arglist
# well-formed.
if idx + 1 < len(children) and children[idx + 1].type == "operator" and children[idx + 1].value == ",":
del children[idx : idx + 2]
elif idx > 0 and children[idx - 1].type == "operator" and children[idx - 1].value == ",":
del children[idx - 1 : idx + 1]
else:
del children[idx]
# If only one argument remains, collapse arglist back to that arg.
remaining_args = [c for c in children if c.type == "argument"]
if len(remaining_args) == 1 and len(children) == 1:
sole = children[0]
grandparent = parent.parent
if grandparent is not None:
gp_children = grandparent.children
gp_idx = gp_children.index(parent)
sole.parent = grandparent
sole.get_first_leaf().prefix = parent.get_first_leaf().prefix
gp_children[gp_idx] = sole
return
# Single-argument trailer: parent is the trailer.
if parent.type == "trailer":
children = parent.children
idx = children.index(arg)
del children[idx]
return
raise ValueError(f"_remove_argument: unexpected parent {parent.type}")
def _first_positional_value(trailer: NodeOrLeaf) -> NodeOrLeaf | None:
"""First positional argument value inside a call trailer, else None."""
inner = trailer.children[1:-1]
if not inner:
return None
if len(inner) == 1:
node = inner[0]
if node.type == "argument":
# Could be keyword; treat name=value as not positional.
if len(node.children) >= 3 and getattr(node.children[1], "value", None) == "=":
return None
return node
if node.type == "arglist":
for c in node.children:
if c.type == "argument":
if len(c.children) >= 3 and getattr(c.children[1], "value", None) == "=":
return None
return c
if c.type == "operator" and c.value == ",":
continue
return c
return None
return node
return None
def _iter_leaves(node: NodeOrLeaf):
    """Yield the leaves under ``node`` depth-first, left to right.

    A leaf passed in directly is yielded as the only item.
    """
    if not isinstance(node, Leaf):
        for child in node.children:
            yield from _iter_leaves(child)
        return
    yield node
def _get_prefix(node: NodeOrLeaf) -> str:
    """Return the parso prefix of ``node``'s first leaf (``node`` itself for a leaf)."""
    first = node if isinstance(node, Leaf) else node.get_first_leaf()
    return first.prefix
def _set_prefix(node: NodeOrLeaf, prefix: str) -> None:
    """Overwrite the parso prefix of ``node``'s first leaf (``node`` itself for a leaf)."""
    target = node if isinstance(node, Leaf) else node.get_first_leaf()
    target.prefix = prefix