Source code for simvx.core.scene_io.symbols

"""Pure CST queries over a parso tree for class definitions and use sites.

This module is the foundation for cross-file class rename. It is stateless
and does no I/O — all functions take a :class:`SourceTree` and walk the
parso CST. ``rename_class_in_source`` and ``rename_module_in_imports`` mutate
the tree in place; callers then call :meth:`SourceTree.dump` on the result.

Use-site classification is purely syntactic (no name resolution): an aliased
import does not propagate to its alias, and a local-variable shadow inside
a function suppresses uses of the class name within that function.
"""

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Literal

from parso.python.tree import Class
from parso.tree import Leaf, NodeOrLeaf

from .source_tree import SourceTree

# Public API of this module; every other name here is a private helper.
__all__ = [
    "ClassDefRef",
    "UseSiteRef",
    "find_class_definitions",
    "find_class_uses",
    "rename_class_in_source",
    "rename_module_in_imports",
]


# Closed set of labels stored in ``UseSiteRef.kind``. The labels are assigned
# purely from local syntax by the classifier helpers in this module.
UseKind = Literal[
    # ``from m import Foo`` / ``import Foo`` (no alias).
    "import",
    # ``from m import Foo as Bar`` / ``import Foo as Bar``.
    "import_alias",
    # Bare name in a class header: ``class X(Foo):``.
    "base_class",
    # ``Foo(...)`` call outside an annotation.
    "instantiation",
    # Second-or-later positional argument of ``isinstance(...)``.
    "isinstance_arg",
    # Parameter, return, or annotated-assignment annotation.
    "annotation",
    # Any other occurrence of the bare name.
    "bare_reference",
]


@dataclass(frozen=True)
class ClassDefRef:
    """Immutable record describing one class definition found in a source tree.

    Instances are produced by ``_classdef_ref`` from a parso ``classdef`` node.
    """

    # Identifier that follows the ``class`` keyword.
    name: str
    # Bare-name bases only; dotted, expression and keyword bases are omitted.
    bases: tuple[str, ...]
    # Line of the ``class`` keyword (``start_pos[0]`` of that leaf).
    line: int
    # The parso ``classdef`` node this record was built from.
    classdef_node: NodeOrLeaf
@dataclass(frozen=True)
class UseSiteRef:
    """One location where a class name is referenced.

    ``leaf`` is the bare-name leaf a rename would rewrite. For an aliased
    import (``from x import Foo as Bar``) it is the source name ``Foo``, never
    the alias ``Bar``.
    """

    # Syntactic classification of the site.
    kind: UseKind
    # Line of ``leaf`` (``start_pos[0]``).
    line: int
    # The name leaf to rewrite.
    leaf: NodeOrLeaf
    # Dotted module of the ``from`` clause for import sites, otherwise None.
    import_module: str | None = None
    # Alias bound by ``as`` for ``import_alias`` sites, otherwise None.
    import_alias: str | None = None
    # Fields below help the rename pass without re-walking.
    _annotation_marker: bool = field(default=False, repr=False, compare=False)
# --------------------------------------------------------------------------- # find_class_definitions # ---------------------------------------------------------------------------
def find_class_definitions(source_tree: SourceTree) -> list[ClassDefRef]:
    """Return a :class:`ClassDefRef` for each class of ``source_tree``, in order.

    The classes are exactly those yielded by ``source_tree.iter_classes()``.
    Per this module's contract those are the top-level classes; classes nested
    inside functions or other classes are not project-rename candidates.
    NOTE(review): the nesting filter lives in ``iter_classes``, not here.
    """
    return [_classdef_ref(classdef) for classdef in source_tree.iter_classes()]
def _classdef_ref(classdef: Class) -> ClassDefRef:
    """Build the :class:`ClassDefRef` record for one parso ``classdef`` node."""
    class_keyword = classdef.children[0]
    return ClassDefRef(
        name=classdef.name.value,
        bases=_extract_bare_bases(classdef),
        line=class_keyword.start_pos[0],  # line of the `class` keyword
        classdef_node=classdef,
    )


def _extract_bare_bases(classdef: Class) -> tuple[str, ...]:
    """Return the bases of ``classdef`` that are written as plain names.

    ``class Foo(A, B):`` gives ``("A", "B")``. Dotted bases (``mod.A``), other
    expressions and keyword arguments (``metaclass=Meta``) are left out.
    """
    # Shape: ['class', name, '(', <bases>?, ')', ':', suite]
    #    or: ['class', name, ':', suite] when there are no parentheses.
    parts = classdef.children
    if len(parts) < 4:
        return ()
    opener, slot = parts[2], parts[3]
    if opener.type != "operator" or opener.value != "(":
        return ()
    if slot.type == "name":
        # Exactly one base, written as a bare name.
        return (slot.value,)
    if slot.type == "arglist":
        # Several entries: keep the names, drop commas / kwargs / expressions.
        return tuple(entry.value for entry in slot.children if entry.type == "name")
    # Empty parentheses (slot is the ')') or a single non-name base.
    return ()


# ---------------------------------------------------------------------------
# find_class_uses
# ---------------------------------------------------------------------------
def find_class_uses(source_tree: SourceTree, class_name: str) -> list[UseSiteRef]:
    """Collect every site that references ``class_name``, tagged with its kind.

    The walk starts at ``source_tree.module`` in an unshadowed state. It is
    scope-aware: a plain ``<class_name> = …`` assignment hides later uses of
    the name in the scope where it appears. The classification is purely
    syntactic and conservative: when in doubt, the site is reported.
    """
    sites: list[UseSiteRef] = []
    _collect_uses(source_tree.module, class_name, shadowed=False, out=sites)
    return sites
# Statement node types that do NOT open a new scope: their headers and bodies
# are walked with the shadow state of the scope that contains them.
_COMPOUND_STMT_TYPES = frozenset(
    {"if_stmt", "while_stmt", "for_stmt", "try_stmt", "with_stmt", "suite", "async_stmt"}
)


def _collect_uses(
    scope_node: NodeOrLeaf,
    class_name: str,
    *,
    shadowed: bool,
    out: list[UseSiteRef],
) -> None:
    """Walk the statements of one scope, tracking shadow state.

    Shadow detection: once a ``simple_stmt`` of the form
    ``<class_name> = <rhs>`` is seen among the direct statements of this scope,
    every later occurrence of the name in the same scope (including inside
    nested control-flow constructs) is suppressed. Function and class bodies
    open a fresh scope, so the shadow does not propagate into them.

    ``scope_node`` is a module, a ``suite``, or (for one-line bodies such as
    ``def f(): return Foo()``) a bare ``simple_stmt``. Results are appended to
    ``out``.
    """
    children = getattr(scope_node, "children", None)
    if children is None:
        return
    if scope_node.type == "simple_stmt":
        # One-line body: the statement *is* the scope, there is no suite to
        # iterate over.
        _walk_compound_child(scope_node, class_name, shadowed=shadowed, out=out)
        return
    for child in children:
        if child.type == "simple_stmt":
            if _is_shadow_assignment(child, class_name):
                # Walk the RHS for uses of class_name (``Foo = Foo + 1`` is
                # rare but legal) before flipping the shadow flag.
                if not shadowed:
                    _classify_simple_stmt_rhs(child, class_name, out)
                shadowed = True
            elif not shadowed:
                _collect_uses_from_simple_stmt(child, class_name, out)
        else:
            # Definitions and compound statements; newline / endmarker /
            # keyword leaves are not handled by the dispatcher and are ignored.
            _dispatch_nested_statement(child, class_name, shadowed=shadowed, out=out)


def _dispatch_nested_statement(
    node: NodeOrLeaf,
    class_name: str,
    *,
    shadowed: bool,
    out: list[UseSiteRef],
) -> bool:
    """Handle a class/function definition or compound statement.

    Returns ``True`` when ``node`` was one of the handled statement types and
    ``False`` otherwise (so callers can decide what to do with the rest).
    ``shadowed`` is the shadow state of the scope that contains ``node``.
    """
    ntype = node.type
    if ntype == "classdef":
        # Bases are looked at in this scope; the body is a fresh scope.
        _collect_uses_from_classdef(node, class_name, out)
        _collect_uses(node.children[-1], class_name, shadowed=False, out=out)
        return True
    if ntype == "funcdef":
        _collect_uses_from_funcdef_signature(node, class_name, out, scope_shadowed=shadowed)
        _collect_uses(node.children[-1], class_name, shadowed=False, out=out)
        return True
    if ntype == "async_funcdef":
        # children: ['async', funcdef] — appears under ``decorated``.
        return _dispatch_nested_statement(node.children[-1], class_name, shadowed=shadowed, out=out)
    if ntype == "decorated":
        # children: decorator | decorators, then classdef | funcdef |
        # async_funcdef. parso wraps two or more decorators in a single
        # ``decorators`` node. Decorators may reference class_name (e.g.
        # ``@Player.method``) and are classified as general expressions in the
        # current scope.
        for sub in node.children:
            if sub.type in ("decorator", "decorators"):
                if not shadowed:
                    _classify_general(sub, class_name, out, in_annotation=False)
            else:
                _dispatch_nested_statement(sub, class_name, shadowed=shadowed, out=out)
        return True
    if ntype in _COMPOUND_STMT_TYPES:
        # Compound statements share the enclosing scope, so every child
        # (header expressions and bodies) is walked with the same shadow state.
        # The walk continues even when shadowed: definitions nested anywhere
        # below still open fresh scopes that must be visited.
        for sub in node.children:
            _walk_compound_child(sub, class_name, shadowed=shadowed, out=out)
        return True
    return False


def _walk_compound_child(
    child: NodeOrLeaf,
    class_name: str,
    *,
    shadowed: bool,
    out: list[UseSiteRef],
) -> None:
    """Dispatch a single child of a compound statement (if/while/for/with/try).

    A shadow assignment found here only has its RHS classified; it does not
    flip any flag, because scope-level shadow state is owned by
    :func:`_collect_uses`.
    """
    if child.type == "simple_stmt":
        if shadowed:
            return
        if _is_shadow_assignment(child, class_name):
            _classify_simple_stmt_rhs(child, class_name, out)
        else:
            _collect_uses_from_simple_stmt(child, class_name, out)
        return
    if _dispatch_nested_statement(child, class_name, shadowed=shadowed, out=out):
        return
    # Otherwise: a header expression (the condition of ``if``, the iterable of
    # ``for``, the context expr of ``with``, an ``except`` clause) or a
    # keyword/operator/newline leaf.
    if shadowed:
        return
    if isinstance(child, Leaf) and child.type != "name":
        return
    # A bare-name header such as ``for c in Color:`` is a single name leaf and
    # must be classified like any other expression.
    _classify_general(child, class_name, out, in_annotation=False)


def _classify_simple_stmt_rhs(stmt: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """For a shadow assignment ``Foo = <rhs>``, classify uses inside <rhs> only.

    The LHS leaf is the shadow target; emitting it would be a false positive.
    """
    expr = stmt.children[0]
    if expr.type != "expr_stmt" or len(expr.children) < 3:
        return
    # children: [target name, '=', value]
    _classify_general(expr.children[2], class_name, out, in_annotation=False)


def _is_shadow_assignment(stmt: NodeOrLeaf, class_name: str) -> bool:
    """True iff ``stmt`` is ``<class_name> = …`` (a simple assignment).

    Conservative: only a plain ``name = value`` whose first statement on the
    line targets ``class_name`` matches (no augmented or tuple targets, no
    annotated assignment). An annotated assignment like ``Foo: SomeType = …``
    is deliberately not treated as shadowing.
    """
    if stmt.type != "simple_stmt" or not stmt.children:
        return False
    expr = stmt.children[0]
    if expr.type != "expr_stmt" or len(expr.children) < 3:
        return False
    target, eq = expr.children[0], expr.children[1]
    if target.type != "name" or target.value != class_name:
        return False
    return isinstance(eq, Leaf) and eq.type == "operator" and eq.value == "="


def _collect_uses_from_simple_stmt(stmt: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Classify every reference to ``class_name`` inside ``stmt``.

    A ``simple_stmt`` may hold several ``;``-separated small statements; each
    of them is inspected. Separator and newline leaves yield nothing.
    """
    for small in stmt.children:
        if small.type == "import_from":
            _classify_import_from(small, class_name, out)
        elif small.type == "import_name":
            # ``import Foo`` — class_name as a top-level import target.
            _classify_import_name(small, class_name, out)
        else:
            _classify_general(small, class_name, out, in_annotation=False)


def _import_use_site(
    node: NodeOrLeaf,
    class_name: str,
    module: str | None,
    *,
    alias_type: str,
) -> UseSiteRef | None:
    """Build the use site for one imported name, or ``None`` if it does not match.

    ``node`` is either a bare ``name`` leaf (kind ``import``) or a node of type
    ``alias_type`` with children ``[name, 'as', alias]`` (kind
    ``import_alias``, whose ``leaf`` is the source name, not the alias).
    """
    if node.type == "name":
        if node.value != class_name:
            return None
        return UseSiteRef(
            kind="import",
            line=node.start_pos[0],
            leaf=node,
            import_module=module,
            import_alias=None,
        )
    if node.type == alias_type:
        src_leaf = node.children[0]
        if src_leaf.type != "name" or src_leaf.value != class_name:
            return None
        alias_leaf = node.children[2] if len(node.children) >= 3 else None
        return UseSiteRef(
            kind="import_alias",
            line=src_leaf.start_pos[0],
            leaf=src_leaf,
            import_module=module,
            import_alias=alias_leaf.value if alias_leaf is not None else None,
        )
    return None


def _classify_import_from(import_from: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Inspect ``from <module> import <names>`` for ``class_name``."""
    module = _from_module_name(import_from)
    names_node = _import_from_names_node(import_from)
    if names_node is None:
        return
    # ``import_as_names`` is the comma-separated list; anything else is a
    # single imported name (plain or aliased) or something we do not match.
    if names_node.type == "import_as_names":
        candidates = names_node.children
    else:
        candidates = [names_node]
    for candidate in candidates:
        site = _import_use_site(candidate, class_name, module, alias_type="import_as_name")
        if site is not None:
            out.append(site)


def _classify_import_name(import_name: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """``import <X>`` — record uses where the imported name matches.

    Only a *bare* ``import Foo`` or ``import Foo as Bar`` is matched;
    ``import a.b.Foo`` and multi-target imports are not handled (rare for
    class imports).
    """
    if len(import_name.children) < 2:
        return
    site = _import_use_site(import_name.children[1], class_name, None, alias_type="dotted_as_name")
    if site is not None:
        out.append(site)


def _collect_uses_from_classdef(classdef: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Detect ``class_name`` in the base list of ``classdef``.

    A base written as the bare name is reported as ``base_class``. Any other
    entry (``Generic[Foo]``, ``metaclass=Foo``, call expressions, …) is walked
    as a general expression so references inside it are not lost; attribute
    names such as the ``Foo`` in ``mod.Foo`` are still skipped by the leaf
    classifier.
    """
    children = classdef.children
    if len(children) < 4 or children[2].type != "operator" or children[2].value != "(":
        return
    bases_slot = children[3]
    entries = bases_slot.children if bases_slot.type == "arglist" else [bases_slot]
    for entry in entries:
        if entry.type == "name":
            if entry.value == class_name:
                out.append(UseSiteRef(kind="base_class", line=entry.start_pos[0], leaf=entry))
        else:
            # Commas and the closing ')' are non-name leaves and yield nothing.
            _classify_general(entry, class_name, out, in_annotation=False)


def _collect_uses_from_funcdef_signature(
    funcdef: NodeOrLeaf,
    class_name: str,
    out: list[UseSiteRef],
    *,
    scope_shadowed: bool,
) -> None:
    """Detect ``class_name`` in a funcdef's signature: param annotations + defaults + return.

    The signature is evaluated in the *enclosing* scope, so it respects the
    shadow flag of that scope (``scope_shadowed``).
    """
    if scope_shadowed:
        return
    # children: [def, name, parameters, (-> return_anno)?, :, suite]
    children = funcdef.children
    parameters = children[2]
    # Annotations and defaults are walked separately so each gets the right
    # annotation context.
    _classify_param_annotations(parameters, class_name, out)
    _classify_param_defaults(parameters, class_name, out)
    # Return annotation, if any: the node right after the ``->`` operator.
    if len(children) > 4 and getattr(children[3], "value", None) == "->":
        _classify_general(children[4], class_name, out, in_annotation=True)


def _classify_param_annotations(parameters: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """For each ``param`` child, classify its annotation slot under in_annotation=True.

    Parso wraps annotated parameters as ``param > tfpdef > [name, ":", anno]``.
    Plain (un-annotated) parameters have ``param > [name]`` directly.
    """
    for param in parameters.children:
        if param.type != "param":
            continue
        tfpdef = next((sub for sub in param.children if sub.type == "tfpdef"), None)
        if tfpdef is None:
            continue
        parts = tfpdef.children
        for i, part in enumerate(parts):
            if isinstance(part, Leaf) and part.type == "operator" and part.value == ":":
                if i + 1 < len(parts):
                    _classify_general(parts[i + 1], class_name, out, in_annotation=True)
                break


def _classify_param_defaults(parameters: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """For each ``param`` child, classify its default value (after ``=``) as a normal expression.

    The default is a sibling of the param's name/tfpdef inside the param node.
    """
    for param in parameters.children:
        if param.type != "param":
            continue
        parts = param.children
        for i, part in enumerate(parts):
            if isinstance(part, Leaf) and part.type == "operator" and part.value == "=":
                if i + 1 < len(parts):
                    _classify_general(parts[i + 1], class_name, out, in_annotation=False)
                break


def _classify_general(
    node: NodeOrLeaf,
    class_name: str,
    out: list[UseSiteRef],
    *,
    in_annotation: bool,
) -> None:
    """Walk ``node`` and emit a UseSiteRef for every bare-name leaf matching ``class_name``.

    Classification uses local syntactic context (parent + sibling structure):

    - inside a ``param`` annotation slot or return annotation → ``annotation``
    - inside an ``annassign`` annotation slot → ``annotation``
    - head of an ``atom_expr`` whose first trailer is a call → ``instantiation``
    - second-or-later argument to ``isinstance(...)`` → ``isinstance_arg``
    - otherwise → ``bare_reference``
    """
    for leaf, kind in _walk_name_leaves(node, class_name, in_annotation):
        out.append(UseSiteRef(kind=kind, line=leaf.start_pos[0], leaf=leaf))


def _walk_name_leaves(
    node: NodeOrLeaf,
    class_name: str,
    in_annotation: bool,
) -> Iterator[tuple[Leaf, UseKind]]:
    """Yield (leaf, kind) for every name leaf in ``node`` matching ``class_name``."""
    ntype = node.type
    if ntype == "annassign":
        # annassign children: [':', annotation, ('=', value)?]
        children = node.children
        if len(children) >= 2:
            yield from _walk_name_leaves(children[1], class_name, in_annotation=True)
        # The value after '=' is an ordinary expression, not an annotation.
        for i, c in enumerate(children):
            if isinstance(c, Leaf) and c.type == "operator" and c.value == "=":
                if i + 1 < len(children):
                    yield from _walk_name_leaves(children[i + 1], class_name, in_annotation=False)
                break
        return

    # Leaf — match name leaves directly.
    if isinstance(node, Leaf):
        if node.type == "name" and node.value == class_name:
            kind = _classify_name_leaf(node, in_annotation=in_annotation)
            if kind is not None:
                yield node, kind
        return

    children = node.children
    # atom_expr: [head, trailer, trailer, ...] — the head gets special
    # classification (instantiation detection) when it is the class name.
    if ntype == "atom_expr" and children:
        head = children[0]
        if isinstance(head, Leaf) and head.type == "name" and head.value == class_name:
            kind = _classify_atom_expr_head(node, in_annotation=in_annotation)
            if kind is not None:
                yield head, kind
            # Walk the trailers only; the head is already handled.
            for c in children[1:]:
                yield from _walk_name_leaves(c, class_name, in_annotation=in_annotation)
            return

    for c in children:
        yield from _walk_name_leaves(c, class_name, in_annotation=in_annotation)


def _classify_name_leaf(leaf: Leaf, *, in_annotation: bool) -> UseKind | None:
    """Classify a bare ``name`` leaf based on its parent shape.

    Returns ``None`` for leaves that are the attribute part of a trailer
    (``.Foo``) — those are not references to the top-level class.
    """
    parent = leaf.parent
    if parent is None:
        return "bare_reference"

    # Trailer ``.Foo`` — leaf is the attribute name, not a class reference.
    if parent.type == "trailer" and parent.children and parent.children[0].value == ".":
        return None

    # Argument to isinstance: the leaf's parent is an ``arglist``, whose parent
    # is a ``trailer``, whose ``atom_expr`` starts with the name ``isinstance``.
    if parent.type == "arglist":
        trailer = parent.parent
        call = trailer.parent if trailer is not None and trailer.type == "trailer" else None
        if call is not None and call.type == "atom_expr":
            head = call.children[0]
            if isinstance(head, Leaf) and head.type == "name" and head.value == "isinstance":
                # Position among the arguments (commas removed):
                # ``isinstance(x, Foo)`` → index 1 is the class argument.
                args = [c for c in parent.children if c.type != "operator"]
                position = next((i for i, arg in enumerate(args) if arg is leaf), None)
                if position is not None and position >= 1:
                    return "isinstance_arg"

    return "annotation" if in_annotation else "bare_reference"


def _classify_atom_expr_head(atom_expr: NodeOrLeaf, *, in_annotation: bool) -> UseKind | None:
    """Classify the head of an atom_expr like ``Foo(...)`` or ``Foo.method``.

    The head leaf is already known to match the class name. Inside an
    annotation the result is always ``annotation`` (a call such as ``Foo()``
    there is treated like the surrounding annotation). Outside annotations a
    head whose first trailer is a call (``(``) is an ``instantiation``;
    attribute access, subscripts and anything else are ``bare_reference``.
    """
    if in_annotation:
        return "annotation"
    children = atom_expr.children
    if len(children) >= 2:
        first_trailer = children[1]
        if first_trailer.type == "trailer" and first_trailer.children:
            opener = first_trailer.children[0]
            if isinstance(opener, Leaf) and opener.type == "operator" and opener.value == "(":
                return "instantiation"
    return "bare_reference"


def _from_module_name(import_from: NodeOrLeaf) -> str | None:
    """Reconstruct the dotted module name from an ``import_from`` node.

    Leading relative-import dots are kept (``from ..pkg.mod import x`` gives
    ``"..pkg.mod"``). Returns ``None`` when nothing sits between ``from`` and
    ``import``.
    """
    parts: list[str] = []
    for c in import_from.children[1:]:
        if c.type == "keyword" and c.value == "import":
            break
        # The tokenizer emits three consecutive dots as ONE ``...`` operator,
        # so both spellings must be accepted to keep the relative level.
        if c.type == "operator" and c.value in (".", "..."):
            parts.append(c.value)
        elif c.type == "name":
            parts.append(c.value)
        elif c.type == "dotted_name":
            parts.append("".join(sub.value for sub in c.children if hasattr(sub, "value")))
    text = "".join(parts)
    return text if text else None


def _import_from_names_node(import_from: NodeOrLeaf) -> NodeOrLeaf | None:
    """Find the names node (after the ``import`` keyword), skipping parentheses."""
    children = import_from.children
    for i, c in enumerate(children):
        if c.type == "keyword" and c.value == "import":
            for nxt in children[i + 1 :]:
                if nxt.type == "operator" and nxt.value in ("(", ")"):
                    continue
                return nxt
    return None


# ---------------------------------------------------------------------------
# rename_class_in_source
# ---------------------------------------------------------------------------
def rename_class_in_source(source_tree: SourceTree, old_name: str, new_name: str) -> None:
    """Rename a class inside ``source_tree``, mutating the tree in place.

    The ``.value`` of these leaves is rewritten from ``old_name`` to
    ``new_name``:

    - the name leaf of every class definition reported by
      :func:`find_class_definitions` whose name is ``old_name``;
    - the leaf of every use site reported by :func:`find_class_uses`. For
      ``import_alias`` sites that leaf is the source name, so the alias itself
      is preserved.

    When ``old_name == new_name`` nothing is done. Occurrences inside string
    literals are never touched: they are ``string`` leaves, not ``name``
    leaves, so they are never reported as use sites.
    """
    if old_name == new_name:
        return

    # Definitions: the class name leaf is always children[1] of the classdef.
    for definition in find_class_definitions(source_tree):
        if definition.name != old_name:
            continue
        name_leaf = definition.classdef_node.children[1]
        if isinstance(name_leaf, Leaf) and name_leaf.value == old_name:
            name_leaf.value = new_name

    # Use sites: every kind carries the leaf to rewrite. Later uses of an
    # import alias are not reported for old_name at all, so they are naturally
    # left alone.
    for site in find_class_uses(source_tree, old_name):
        target = site.leaf
        if isinstance(target, Leaf) and target.value == old_name:
            target.value = new_name
# --------------------------------------------------------------------------- # rename_module_in_imports — for file-rename cascades # ---------------------------------------------------------------------------
def rename_module_in_imports(
    source_tree: SourceTree,
    old_module: str,
    new_module: str,
) -> None:
    """Point ``from <old_module> import …`` statements at ``new_module``.

    Used when the file (or folder) that defines a class is renamed and every
    importer's module path has to follow. ``old_module`` and ``new_module``
    are dotted names in the form found in the ``from`` clause (e.g.
    ``"src.player"`` and ``"src.hero"``).

    An import is rewritten when either:

    - **Exact** — its module equals ``old_module``; it is replaced by
      ``new_module`` as a whole.
    - **Trailing segment** — its last dotted component equals the last
      component of ``old_module`` (``from .player import X`` and
      ``from foo.bar.player import X`` for ``old_module="src.player"``). Only
      that last component is replaced; leading dots and the prefix are kept.

    Imports whose last component differs (``from os.path import X`` when the
    rename has nothing to do with ``path``) are left alone. When
    ``old_module == new_module`` nothing is done. Plain ``import x``
    statements (no ``from``) are not handled here.
    """
    if old_module == new_module:
        return

    old_leaf = old_module.rsplit(".", 1)[-1]
    new_leaf = new_module.rsplit(".", 1)[-1]

    for node in _iter_import_from(source_tree.module):
        current = _from_module_name(node)
        if current is None:
            continue
        if current == old_module:
            _replace_from_module(node, new_module)
            continue

        # Separate the relative-import dots from the dotted path, then split
        # the path into "everything before the last dot" and the last segment.
        dotless = current.lstrip(".")
        dots = current[: len(current) - len(dotless)]
        head, sep, tail = dotless.rpartition(".")
        if tail != old_leaf:
            continue
        # Swap only the trailing segment; ``head``/``sep`` are empty for a
        # single-segment module.
        _replace_from_module(node, dots + head + sep + new_leaf)
def _iter_import_from(scope_node: NodeOrLeaf) -> Iterator[NodeOrLeaf]:
    """Yield every ``import_from`` node below ``scope_node``, in tree order.

    ``scope_node`` itself is never yielded, and the children of a yielded
    ``import_from`` are not searched further.
    """
    for child in getattr(scope_node, "children", None) or ():
        if child.type == "import_from":
            yield child
            continue
        # Leaves have no ``children`` and simply produce nothing.
        yield from _iter_import_from(child)


def _module_span(children) -> tuple[int | None, int | None]:
    """Locate the module part of an ``import_from`` child list.

    Returns ``(start, end)`` where ``start`` is the index right after the
    ``from`` keyword and ``end`` the index of the ``import`` keyword; either is
    ``None`` when the keyword was not found.
    """
    start = end = None
    for idx, node in enumerate(children):
        if node.type != "keyword":
            continue
        if node.value == "from":
            start = idx + 1
        elif node.value == "import":
            end = idx
            break
    return start, end


def _replace_from_module(import_from: NodeOrLeaf, new_module: str) -> None:
    """Replace the dotted module portion of an ``import_from`` node in place.

    The first replacement leaf takes over the prefix of the first leaf it
    replaces, so the whitespace after ``from`` is preserved. The node is left
    untouched when its module part, or the parsed replacement, cannot be
    located.
    """
    # children: [keyword 'from', <module-parts>, keyword 'import', ...]
    children = import_from.children
    start_idx, end_idx = _module_span(children)
    if start_idx is None or end_idx is None or start_idx >= end_idx:
        return

    # Remember the prefix of the first module leaf to keep whitespace stable.
    first = children[start_idx]
    first_leaf = first if isinstance(first, Leaf) else first.get_first_leaf()
    saved_prefix = first_leaf.prefix

    # Obtain replacement nodes by parsing ``from <new> import _`` and lifting
    # the module part out of the resulting ``import_from``.
    from .source_tree import parse_snippet

    snippet = parse_snippet(f"from {new_module} import _\n")
    snippet_inner = next(
        (c for c in getattr(snippet, "children", ()) if c.type == "import_from"),
        None,
    )
    if snippet_inner is None:
        return

    sn_start, sn_end = _module_span(snippet_inner.children)
    if sn_start is None or sn_end is None:
        return
    new_module_nodes = snippet_inner.children[sn_start:sn_end]
    if not new_module_nodes:
        return

    # Restore the original whitespace prefix on the first new node.
    new_first = new_module_nodes[0]
    new_first_leaf = new_first if isinstance(new_first, Leaf) else new_first.get_first_leaf()
    new_first_leaf.prefix = saved_prefix

    # Splice: [start_idx:end_idx] is replaced by the new nodes, re-parented
    # onto this import.
    spliced = list(children[:start_idx]) + list(new_module_nodes) + list(children[end_idx:])
    for node in new_module_nodes:
        node.parent = import_from
    import_from.children[:] = spliced