"""Pure CST queries over a parso tree for class definitions and use sites.
This module is the foundation for cross-file class rename. It is stateless
and does no I/O — all functions take a :class:`SourceTree` and walk the
parso CST. ``rename_class_in_source`` mutates the tree in place; callers
:meth:`SourceTree.dump` the result.
Use-site classification is purely syntactic (no name resolution): an aliased
import does not propagate to its alias, and a ``<ClassName> = …`` assignment
made directly in a scope (module or function body) suppresses the later uses
of the class name within that scope.
"""
from __future__ import annotations
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Literal
from parso.python.tree import Class
from parso.tree import Leaf, NodeOrLeaf
from .source_tree import SourceTree
# Public API of this module; every underscore-prefixed helper below is an
# implementation detail.
__all__ = [
    "ClassDefRef",
    "UseSiteRef",
    "find_class_definitions",
    "find_class_uses",
    "rename_class_in_source",
    "rename_module_in_imports",
]
# Label stored in ``UseSiteRef.kind``:
# - "import":         ``from m import Foo`` / ``from m import (A, Foo)`` / bare ``import Foo``
# - "import_alias":   ``from m import Foo as Bar`` / ``import Foo as Bar`` (leaf is ``Foo``)
# - "base_class":     a bare-name base in ``class X(Foo):``
# - "instantiation":  ``Foo(...)`` outside an annotation
# - "isinstance_arg": a non-first argument of ``isinstance(...)``
# - "annotation":     inside a parameter / return / ``x: <anno>`` annotation
# - "bare_reference": any other reference (``Foo.attr``, ``Foo[...]``, plain ``Foo``)
UseKind = Literal[
    "import",
    "import_alias",
    "base_class",
    "instantiation",
    "isinstance_arg",
    "annotation",
    "bare_reference",
]
[docs]
@dataclass(frozen=True)
class ClassDefRef:
    """A top-level class definition discovered in a source tree.

    Instances are built by :func:`find_class_definitions`.
    """

    # The class name as written after the ``class`` keyword.
    name: str
    # Bare-name bases only: ``class Foo(A, mod.B, metaclass=M)`` -> ``("A",)``.
    bases: tuple[str, ...]
    # Line of the ``class`` keyword (parso ``start_pos[0]``).
    line: int
    # The parso ``classdef`` node; ``rename_class_in_source`` rewrites its
    # name leaf (``children[1]``) in place.
    classdef_node: NodeOrLeaf
[docs]
@dataclass(frozen=True)
class UseSiteRef:
    """A site where a class name is referenced.

    ``leaf`` is the bare-name leaf to potentially rewrite. For aliased
    imports (``from x import Foo as Bar``) ``leaf`` points at the source
    name ``Foo`` (not the alias).
    """

    # How the name is used at this site (one of the ``UseKind`` labels).
    kind: UseKind
    # Line number taken from ``leaf.start_pos[0]``.
    line: int
    # The parso name leaf spelling the class name; renames rewrite its ``.value``.
    leaf: NodeOrLeaf
    # For ``from <module> import …`` sites: the module text as reconstructed
    # from the ``from`` clause (leading dots of relative imports kept). ``None``
    # for bare ``import`` statements and for non-import kinds.
    import_module: str | None = None
    # The ``as`` alias of an ``import_alias`` site; ``None`` otherwise.
    import_alias: str | None = None
    # Fields below help the rename pass without re-walking.
    # NOTE(review): nothing in this module sets or reads ``_annotation_marker``;
    # confirm whether it is still needed.
    _annotation_marker: bool = field(default=False, repr=False, compare=False)
# ---------------------------------------------------------------------------
# find_class_definitions
# ---------------------------------------------------------------------------
[docs]
def find_class_definitions(source_tree: SourceTree) -> list[ClassDefRef]:
    """Return one :class:`ClassDefRef` per top-level class, in source order.

    The classes come from ``source_tree.iter_classes()``, which is expected to
    yield only module-level classes. Classes nested inside functions or other
    classes are not project-rename candidates.
    """
    return [_classdef_ref(node) for node in source_tree.iter_classes()]
def _classdef_ref(classdef: Class) -> ClassDefRef:
    """Summarise one parso ``classdef`` node as a :class:`ClassDefRef`."""
    class_keyword = classdef.children[0]
    return ClassDefRef(
        name=classdef.name.value,
        bases=_extract_bare_bases(classdef),
        # Report the line of the ``class`` keyword itself.
        line=class_keyword.start_pos[0],
        classdef_node=classdef,
    )
def _extract_bare_bases(classdef: Class) -> tuple[str, ...]:
    """List the bases of ``classdef`` that are plain names.

    ``class Foo(A, B):`` gives ``("A", "B")``. Dotted bases (``mod.A``),
    other expressions and keyword arguments (``metaclass=Meta``) are left out.
    A class without parentheses, or with empty ones, gives ``()``.
    """
    kids = classdef.children
    # With parentheses: ['class', name, '(', <bases>?, ')', ':', suite];
    # without: ['class', name, ':', suite].
    has_parens = len(kids) >= 4 and kids[2].type == "operator" and kids[2].value == "("
    if not has_parens:
        return ()
    # The slot after '(' is a single name, an arglist, a ')' (empty parens),
    # or some other expression node.
    slot = kids[3]
    if slot.type == "name":
        return (slot.value,)
    if slot.type == "arglist":
        # Commas are operators; kwargs are `argument` nodes; dotted bases
        # are expression nodes. Only bare name leaves survive the filter.
        return tuple(arg.value for arg in slot.children if arg.type == "name")
    return ()
# ---------------------------------------------------------------------------
# find_class_uses
# ---------------------------------------------------------------------------
[docs]
def find_class_uses(source_tree: SourceTree, class_name: str) -> list[UseSiteRef]:
    """Return every site that references ``class_name``, each tagged with a kind.

    The walk is scope-aware. A plain ``<class_name> = …`` assignment made
    directly in a scope (module level or a function body) hides the later
    uses of the name in that scope. When the syntax is ambiguous, the site is
    included rather than dropped.
    """
    sites: list[UseSiteRef] = []
    _collect_uses(source_tree.module, class_name, shadowed=False, out=sites)
    return sites
# Node types whose children share the enclosing scope and are walked with the
# current shadow state: compound statements and their suites, plus the
# structural wrappers around definitions (``decorated`` holds a decorator or
# ``decorators`` node and the definition; ``async_funcdef`` is
# ``['async', funcdef]``). The definitions themselves still open fresh scopes.
_SCOPE_SHARING_TYPES = frozenset(
    {
        "if_stmt",
        "while_stmt",
        "for_stmt",
        "try_stmt",
        "with_stmt",
        "suite",
        "async_stmt",
        "decorated",
        "decorators",
        "async_funcdef",
    }
)


def _collect_uses(
    scope_node: NodeOrLeaf,
    class_name: str,
    *,
    shadowed: bool,
    out: list[UseSiteRef],
) -> None:
    """Walk the statements of one scope, tracking shadow state.

    ``scope_node`` is the module, an indented ``suite`` body, or a bare
    ``simple_stmt``. The last case is a one-line body such as
    ``def f(): return Foo()``, which parso attaches directly to the
    definition; it is treated as a single-statement scope.

    Shadow detection: once a ``simple_stmt`` of the form
    ``<class_name> = <rhs>`` appears directly in this scope, every later
    occurrence of the name in the same scope is suppressed. This includes
    occurrences inside nested control-flow constructs. Function and class
    bodies open a fresh scope, so the shadow does not propagate into them.
    """
    if scope_node.type == "simple_stmt":
        statements = [scope_node]
    else:
        statements = getattr(scope_node, "children", None)
        if statements is None:
            return
    for stmt in statements:
        if stmt.type == "simple_stmt" and _is_shadow_assignment(stmt, class_name):
            # Walk the RHS for uses of class_name (``Foo = Foo`` is legal)
            # before flipping the shadow flag.
            if not shadowed:
                _classify_simple_stmt_rhs(stmt, class_name, out)
            shadowed = True
            continue
        # Everything else goes through the same dispatcher used for nested
        # nodes: plain statements, definitions, compound statements, and
        # newline/endmarker leaves (which are ignored).
        _walk_compound_child(stmt, class_name, shadowed=shadowed, out=out)


def _walk_compound_child(
    child: NodeOrLeaf,
    class_name: str,
    *,
    shadowed: bool,
    out: list[UseSiteRef],
) -> None:
    """Dispatch one node found inside a scope or inside a compound statement.

    - ``simple_stmt``: classified unless ``shadowed``. A nested
      ``<class_name> = …`` only has its RHS classified. It does not flip the
      shadow flag; only :func:`_collect_uses` tracks shadow state, and only
      for statements sitting directly in the scope.
    - ``classdef`` / ``funcdef`` / ``decorator``: bases, decorators and
      signatures are evaluated in the enclosing scope, so they respect
      ``shadowed``. Bodies open a fresh, unshadowed scope.
    - Compound statements, suites and definition wrappers
      (``_SCOPE_SHARING_TYPES``): recurse with the same shadow state. This
      happens even when shadowed, so that function and class bodies nested
      anywhere below are still visited.
    - Any other node is classified as a general expression unless shadowed.
      Examples: an ``if`` condition, a ``for`` iterable, a ``with`` item, an
      ``except`` clause. Bare leaves (keywords, operators, newlines) are
      ignored.
    """
    ctype = child.type
    if ctype == "simple_stmt":
        if shadowed:
            return
        if _is_shadow_assignment(child, class_name):
            # The LHS is the shadowing target, not a use: emit the RHS only.
            _classify_simple_stmt_rhs(child, class_name, out)
        else:
            _collect_uses_from_simple_stmt(child, class_name, out)
        return
    if ctype == "classdef":
        if not shadowed:
            _collect_uses_from_classdef(child, class_name, out)
        _collect_uses(child.children[-1], class_name, shadowed=False, out=out)
        return
    if ctype == "funcdef":
        _collect_uses_from_funcdef_signature(child, class_name, out, scope_shadowed=shadowed)
        _collect_uses(child.children[-1], class_name, shadowed=False, out=out)
        return
    if ctype == "decorator":
        # Decorators may reference class_name (e.g. ``@Player.method``).
        if not shadowed:
            _classify_general(child, class_name, out, in_annotation=False)
        return
    if ctype in _SCOPE_SHARING_TYPES:
        for sub in child.children:
            _walk_compound_child(sub, class_name, shadowed=shadowed, out=out)
        return
    if shadowed or isinstance(child, Leaf):
        return
    # Header expression of a compound statement, an except clause, etc.
    _classify_general(child, class_name, out, in_annotation=False)
def _classify_simple_stmt_rhs(stmt: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """For a shadow assignment ``Foo = <value>``, classify uses inside the value only.

    ``stmt`` is a ``simple_stmt`` whose first child is an ``expr_stmt``. Its
    children are ``[Foo, '=', value]``, or for a chained assignment
    ``[Foo, '=', target2, '=', …, value]``. Only the final child is
    classified, because it is the one evaluated before any binding happens:

    - the leading ``Foo`` is the shadowing target, so emitting it would be a
      false positive;
    - later targets are bound left to right *after* ``Foo`` has been rebound,
      so a ``Foo`` inside them already refers to the shadow.
    """
    expr = stmt.children[0]
    if expr.type != "expr_stmt" or len(expr.children) < 3:
        return
    # The evaluated value is always the last child; for a non-chained
    # assignment this is children[2].
    value = expr.children[-1]
    _classify_general(value, class_name, out, in_annotation=False)
def _is_shadow_assignment(stmt: NodeOrLeaf, class_name: str) -> bool:
    """Report whether ``stmt`` is a plain ``<class_name> = …`` rebinding.

    Deliberately narrow. Only a ``simple_stmt`` whose first small statement
    is ``name = value`` with ``name == class_name`` counts. These do not
    count:

    - augmented assignments (``Foo += 1``);
    - tuple targets (``Foo, x = …``);
    - annotated assignments (``Foo: SomeType = …``), where the annotation
      suggests the binding is still class-typed.
    """
    if stmt.type != "simple_stmt" or not stmt.children:
        return False
    first = stmt.children[0]
    if first.type != "expr_stmt" or len(first.children) < 3:
        return False
    lhs, op = first.children[0], first.children[1]
    return (
        lhs.type == "name"
        and lhs.value == class_name
        and isinstance(op, Leaf)
        and op.type == "operator"
        and op.value == "="
    )
def _collect_uses_from_simple_stmt(stmt: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Classify every reference to ``class_name`` inside ``stmt``.

    A ``simple_stmt`` holds one or more small statements separated by ``;``,
    followed by a newline leaf (e.g. ``import os; x = Foo()``). Each small
    statement is handled on its own:

    - ``from … import`` and ``import …`` statements get import-specific
      classification;
    - anything else is walked as a general expression.

    ``;`` separators and the newline leaf contain no names, so walking them
    emits nothing.
    """
    for small_stmt in stmt.children:
        stype = small_stmt.type
        if stype == "import_from":
            _classify_import_from(small_stmt, class_name, out)
        elif stype == "import_name":
            # ``import Foo`` — class_name as a top-level import target.
            _classify_import_name(small_stmt, class_name, out)
        else:
            _classify_general(small_stmt, class_name, out, in_annotation=False)
def _classify_import_from(import_from: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Record ``class_name`` among the imported names of ``from <module> import …``.

    - A plain imported name yields an ``import`` site.
    - ``Foo as Bar`` yields an ``import_alias`` site whose leaf is ``Foo`` and
      whose ``import_alias`` is ``"Bar"``.
    - ``import *`` never matches.
    """
    module = _from_module_name(import_from)
    names_node = _import_from_names_node(import_from)
    if names_node is None:
        return
    # A single imported entry appears on its own; several are wrapped in an
    # ``import_as_names`` node alongside their comma operators.
    if names_node.type == "import_as_names":
        entries = names_node.children
    else:
        entries = [names_node]
    for entry in entries:
        if entry.type == "name":
            if entry.value != class_name:
                continue
            out.append(
                UseSiteRef(
                    kind="import",
                    line=entry.start_pos[0],
                    leaf=entry,
                    import_module=module,
                    import_alias=None,
                )
            )
        elif entry.type == "import_as_name":
            # entry children: [source name, 'as', alias]
            source = entry.children[0]
            if source.type != "name" or source.value != class_name:
                continue
            alias = entry.children[2] if len(entry.children) >= 3 else None
            out.append(
                UseSiteRef(
                    kind="import_alias",
                    line=source.start_pos[0],
                    leaf=source,
                    import_module=module,
                    import_alias=None if alias is None else alias.value,
                )
            )
def _classify_import_name(import_name: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Record a bare ``import <class_name>`` or ``import <class_name> as X``.

    Only the single-target forms are recognised. Dotted targets
    (``import a.b.Foo``) and multi-target statements (``import A, Foo``) are
    ignored; they are rare for class imports.
    """
    if len(import_name.children) < 2:
        return
    target = import_name.children[1]
    if target.type == "name":
        if target.value == class_name:
            out.append(
                UseSiteRef(
                    kind="import",
                    line=target.start_pos[0],
                    leaf=target,
                    import_module=None,
                    import_alias=None,
                )
            )
        return
    if target.type != "dotted_as_name":
        return
    # ``import Foo as Bar`` — children: [source name, 'as', alias]
    source = target.children[0]
    if source.type != "name" or source.value != class_name:
        return
    alias = target.children[2] if len(target.children) >= 3 else None
    out.append(
        UseSiteRef(
            kind="import_alias",
            line=source.start_pos[0],
            leaf=source,
            import_module=None,
            import_alias=None if alias is None else alias.value,
        )
    )
def _collect_uses_from_classdef(classdef: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Emit a ``base_class`` site for every bare-name base of ``classdef`` equal to ``class_name``.

    Dotted bases, expressions and keyword arguments are not inspected.
    """
    kids = classdef.children
    if len(kids) < 4 or kids[2].type != "operator" or kids[2].value != "(":
        return
    slot = kids[3]
    if slot.type == "name":
        candidates = [slot]
    elif slot.type == "arglist":
        candidates = [arg for arg in slot.children if arg.type == "name"]
    else:
        candidates = []
    out.extend(
        UseSiteRef(kind="base_class", line=base.start_pos[0], leaf=base)
        for base in candidates
        if base.value == class_name
    )
def _collect_uses_from_funcdef_signature(
    funcdef: NodeOrLeaf,
    class_name: str,
    out: list[UseSiteRef],
    *,
    scope_shadowed: bool,
) -> None:
    """Classify ``class_name`` in a function signature.

    Covers parameter annotations, parameter defaults and the return
    annotation. A signature is evaluated in the *enclosing* scope, so nothing
    is emitted when that scope has shadowed the name (``scope_shadowed``).
    """
    if scope_shadowed:
        return
    parts = funcdef.children
    # parts: ['def', name, parameters, ('->', return_annotation)?, ':', suite]
    params = parts[2]
    _classify_param_annotations(params, class_name, out)
    _classify_param_defaults(params, class_name, out)
    has_arrow = len(parts) > 3 and getattr(parts[3], "value", None) == "->"
    if has_arrow and len(parts) > 4:
        _classify_general(parts[4], class_name, out, in_annotation=True)
def _classify_param_annotations(parameters: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Classify every parameter annotation in ``parameters`` with ``in_annotation=True``.

    parso nests an annotated parameter as ``param > tfpdef > [name, ':', anno]``.
    An un-annotated parameter has no ``tfpdef`` child and is skipped.
    """
    for param in parameters.children:
        if param.type != "param":
            continue
        tfpdef = next((part for part in param.children if part.type == "tfpdef"), None)
        if tfpdef is None:
            continue
        pieces = tfpdef.children
        colon_at = next(
            (
                pos
                for pos, piece in enumerate(pieces)
                if isinstance(piece, Leaf) and piece.type == "operator" and piece.value == ":"
            ),
            None,
        )
        if colon_at is not None and colon_at + 1 < len(pieces):
            _classify_general(pieces[colon_at + 1], class_name, out, in_annotation=True)
def _classify_param_defaults(parameters: NodeOrLeaf, class_name: str, out: list[UseSiteRef]) -> None:
    """Classify each parameter's default value (the node after ``=``) as a normal expression.

    The default is a direct child of the ``param`` node, a sibling of the
    parameter's name or ``tfpdef``.
    """
    for param in parameters.children:
        if param.type != "param":
            continue
        pieces = param.children
        eq_at = next(
            (
                pos
                for pos, piece in enumerate(pieces)
                if isinstance(piece, Leaf) and piece.type == "operator" and piece.value == "="
            ),
            None,
        )
        if eq_at is not None and eq_at + 1 < len(pieces):
            _classify_general(pieces[eq_at + 1], class_name, out, in_annotation=False)
def _classify_general(
    node: NodeOrLeaf,
    class_name: str,
    out: list[UseSiteRef],
    *,
    in_annotation: bool,
) -> None:
    """Append a :class:`UseSiteRef` to ``out`` for each use of ``class_name`` under ``node``.

    The walk and the per-leaf decision live in :func:`_walk_name_leaves`. The
    resulting kinds are:

    - ``annotation`` — under a parameter/return annotation
      (``in_annotation=True``) or the annotation slot of ``x: <anno> = …``;
    - ``instantiation`` — ``Foo(...)`` outside an annotation;
    - ``isinstance_arg`` — a non-first argument of ``isinstance(...)``;
    - ``bare_reference`` — anything else (``Foo.attr``, ``Foo[...]``, ``Foo``).
    """
    out.extend(
        UseSiteRef(kind=kind, line=leaf.start_pos[0], leaf=leaf)
        for leaf, kind in _walk_name_leaves(node, class_name, in_annotation)
    )
def _walk_name_leaves(
    node: NodeOrLeaf,
    class_name: str,
    in_annotation: bool,
) -> Iterator[tuple[Leaf, UseKind]]:
    """Yield ``(leaf, kind)`` for every ``name`` leaf under ``node`` spelling ``class_name``.

    ``in_annotation`` is inherited by the whole subtree, except where the
    syntax says otherwise:

    - ``annassign`` (``x: <anno> = <value>``): the annotation is walked with
      ``in_annotation=True`` and the value with ``in_annotation=False``.
    - ``atom_expr`` whose head *is* ``class_name`` (``Foo(...)``, ``Foo.x``,
      ``Foo[...]``): the head is classified by
      :func:`_classify_atom_expr_head` and only the trailers are walked
      further. Any other head (``(Foo)``, ``[Foo]``, ``await``, another name)
      is walked like an ordinary child, because it may itself contain uses.

    Matching leaves are classified by :func:`_classify_name_leaf`, which may
    drop them by returning ``None``.
    """
    ntype = node.type
    if ntype == "annassign":
        # annassign children: [':', annotation, ('=', value)?]
        children = node.children
        if len(children) >= 2:
            yield from _walk_name_leaves(children[1], class_name, in_annotation=True)
        # The value (after '=') is an ordinary expression.
        for i, c in enumerate(children):
            if isinstance(c, Leaf) and c.type == "operator" and c.value == "=":
                if i + 1 < len(children):
                    yield from _walk_name_leaves(children[i + 1], class_name, in_annotation=False)
                break
        return
    if isinstance(node, Leaf):
        if node.type == "name" and node.value == class_name:
            kind = _classify_name_leaf(node, in_annotation=in_annotation)
            if kind is not None:
                yield node, kind
        return
    children = node.children
    if ntype == "atom_expr" and children:
        head = children[0]
        if isinstance(head, Leaf) and head.type == "name" and head.value == class_name:
            # Instantiation vs. plain reference is decided by the first
            # trailer; ``isinstance(x, Foo)`` is recognised at the leaf level.
            kind = _classify_atom_expr_head(node, in_annotation=in_annotation)
            if kind is not None:
                yield head, kind
            # The head is handled; only the trailers remain.
            children = children[1:]
    for c in children:
        yield from _walk_name_leaves(c, class_name, in_annotation=in_annotation)
def _classify_name_leaf(leaf: Leaf, *, in_annotation: bool) -> UseKind | None:
    """Classify a bare ``name`` leaf from its surrounding syntax.

    Returns ``None`` for leaves that only spell the class name without
    referring to it:

    - attribute names in an access trailer (``obj.Foo``);
    - keyword-argument names in a call (``f(Foo=1)``). Renaming these would
      change which parameter the call binds.

    Otherwise returns ``"isinstance_arg"`` for a class argument of
    ``isinstance`` (see :func:`_is_isinstance_class_arg`). Failing that, it
    returns ``"annotation"`` or ``"bare_reference"`` depending on
    ``in_annotation``. A detached leaf (no parent) is a ``"bare_reference"``.
    """
    parent = leaf.parent
    if parent is None:
        return "bare_reference"
    siblings = parent.children
    first = siblings[0] if siblings else None
    # Trailer ``.Foo`` — the leaf is an attribute name, not a class reference.
    if parent.type == "trailer" and getattr(first, "value", None) == ".":
        return None
    # ``argument`` node ``[name, '=', value]`` — the leaf is a keyword name.
    if (
        parent.type == "argument"
        and first is leaf
        and len(siblings) >= 3
        and getattr(siblings[1], "value", None) == "="
    ):
        return None
    if _is_isinstance_class_arg(leaf):
        return "isinstance_arg"
    if in_annotation:
        return "annotation"
    return "bare_reference"


def _is_isinstance_class_arg(leaf: Leaf) -> bool:
    """True iff ``leaf`` is a non-first argument of an ``isinstance(...)`` call.

    Accepts the direct form ``isinstance(obj, Foo)``. It also accepts a class
    inside a parenthesised tuple, e.g. ``isinstance(obj, (Foo, Bar))`` and
    ``isinstance(obj, (Foo,))``. The call must be ``isinstance`` immediately
    followed by its argument trailer.
    """
    arg: NodeOrLeaf = leaf
    parent = arg.parent
    # Step out of a tuple: (Foo, Bar) is atom['(', testlist_comp, ')'].
    if parent is not None and parent.type == "testlist_comp":
        arg, parent = parent, parent.parent
    # Step out of the parentheses themselves: (Foo) is atom['(', Foo, ')'].
    if (
        parent is not None
        and parent.type == "atom"
        and getattr(parent.children[0], "value", None) == "("
    ):
        arg, parent = parent, parent.parent
    # Argument shape: arglist inside trailer '(' … ')' inside atom_expr.
    if parent is None or parent.type != "arglist":
        return False
    trailer = parent.parent
    if trailer is None or trailer.type != "trailer":
        return False
    call = trailer.parent
    if call is None or call.type != "atom_expr" or len(call.children) < 2:
        return False
    head = call.children[0]
    if head.type != "name" or getattr(head, "value", None) != "isinstance":
        return False
    if call.children[1] is not trailer:
        return False
    # The first positional argument is the object being tested; any later one
    # is the class specification.
    positional = [c for c in parent.children if c.type != "operator"]
    return any(c is arg for c in positional[1:])
def _classify_atom_expr_head(atom_expr: NodeOrLeaf, *, in_annotation: bool) -> UseKind | None:
    """Classify the head leaf of an ``atom_expr`` already known to spell the class name.

    - Inside an annotation, every shape (``Foo``, ``Foo()``, ``Foo.x``,
      ``Foo[X]``) is an ``annotation``.
    - Outside one, the head is an ``instantiation`` when the first trailer is
      a call ``(...)``.
    - Anything else is a ``bare_reference``: attribute access ``Foo.x``,
      subscript/generic ``Foo[X]``, or no trailer at all.
    """
    if in_annotation:
        return "annotation"
    kids = atom_expr.children
    if len(kids) >= 2:
        trailer = kids[1]
        if trailer.type == "trailer" and trailer.children:
            opener = trailer.children[0]
            if isinstance(opener, Leaf) and opener.type == "operator" and opener.value == "(":
                return "instantiation"
    return "bare_reference"
def _from_module_name(import_from: NodeOrLeaf) -> str | None:
    """Reconstruct the module text between ``from`` and ``import`` in an ``import_from`` node.

    Leading dots of relative imports are kept:

    - ``from ..pkg.mod import x`` gives ``"..pkg.mod"``;
    - ``from . import x`` gives ``"."``.

    parso tokenizes three consecutive dots as a single ``...`` operator, so
    that token is kept too: ``from ...pkg import x`` gives ``"...pkg"``.

    Returns ``None`` only when nothing precedes the ``import`` keyword.
    """
    parts: list[str] = []
    for c in import_from.children[1:]:
        ctype = c.type
        if ctype == "keyword" and c.value == "import":
            break
        if ctype == "operator" and c.value in (".", "..."):
            parts.append(c.value)
        elif ctype == "name":
            parts.append(c.value)
        elif ctype == "dotted_name":
            # dotted_name children alternate name leaves and '.' operators.
            parts.append("".join(getattr(sub, "value", "") for sub in c.children))
    return "".join(parts) or None
def _import_from_names_node(import_from: NodeOrLeaf) -> NodeOrLeaf | None:
    """Return the node listing the imported names (the first non-paren child after ``import``).

    Returns ``None`` when there is no ``import`` keyword or nothing but
    parentheses follows it.
    """
    kids = list(import_from.children)
    for pos, kid in enumerate(kids):
        if kid.type == "keyword" and kid.value == "import":
            trailing = kids[pos + 1 :]
            break
    else:
        return None
    # Skip the optional parentheses of ``from x import (a, b)``.
    return next(
        (n for n in trailing if not (n.type == "operator" and n.value in ("(", ")"))),
        None,
    )
# ---------------------------------------------------------------------------
# rename_class_in_source
# ---------------------------------------------------------------------------
[docs]
def rename_class_in_source(source_tree: SourceTree, old_name: str, new_name: str) -> None:
    """Rename class ``old_name`` to ``new_name`` inside ``source_tree``, mutating leaves in place.

    Two kinds of leaf are rewritten:

    1. The name leaf of each top-level ``class <old_name>`` definition, as
       reported by :func:`find_class_definitions`.
    2. The leaf of every site reported by :func:`find_class_uses`. For
       ``import_alias`` sites that leaf is the source name, so
       ``from m import Old as Alias`` becomes ``from m import New as Alias``.
       Later uses of ``Alias`` are never reported and therefore stay as they
       are.

    Names inside string literals are untouched, because parso gives them
    ``string`` leaves rather than ``name`` leaves. Does nothing when
    ``old_name == new_name``.
    """
    if old_name == new_name:
        return
    # 1. Definitions — parso keeps the class name leaf at classdef.children[1].
    for definition in find_class_definitions(source_tree):
        if definition.name != old_name:
            continue
        name_leaf = definition.classdef_node.children[1]
        if isinstance(name_leaf, Leaf) and name_leaf.value == old_name:
            name_leaf.value = new_name
    # 2. Use sites.
    for site in find_class_uses(source_tree, old_name):
        target = site.leaf
        if isinstance(target, Leaf) and target.value == old_name:
            target.value = new_name
# ---------------------------------------------------------------------------
# rename_module_in_imports — for file-rename cascades
# ---------------------------------------------------------------------------
[docs]
def rename_module_in_imports(
    source_tree: SourceTree,
    old_module: str,
    new_module: str,
) -> None:
    """Point ``from <old_module> import …`` statements at ``new_module``.

    Used when the user renames a class's defining file (or folder), so that
    every importer's module path follows. Both names are dotted, in the form
    parso shows after ``from`` (e.g. ``"src.player"`` -> ``"src.hero"``).

    ``from`` clauses anywhere in the tree are rewritten in two cases:

    - **Exact match** — ``from src.player import X`` when
      ``old_module == "src.player"``. The clause becomes ``new_module``
      verbatim.
    - **Trailing-segment match** — the clause's last dotted component equals
      the last component of ``old_module``. Examples: ``from .player import X``,
      ``from foo.bar.player import X``, ``from player import X``. Only that
      component is swapped for the last component of ``new_module``; leading
      dots and any package prefix are kept as written. This lets relative
      imports follow the rename.

    Anchoring on the trailing segment avoids sweeping in unrelated modules:
    ``from os.path import X`` is left alone unless the rename target's last
    component is ``path``.

    No-op when ``old_module == new_module``. Plain ``import x`` statements (no
    ``from``) are not handled; they would warrant a separate primitive.
    """
    if old_module == new_module:
        return
    old_tail = old_module.rpartition(".")[2]
    new_tail = new_module.rpartition(".")[2]
    for import_from in _iter_import_from(source_tree.module):
        current = _from_module_name(import_from)
        if current is None:
            continue
        if current == old_module:
            _replace_from_module(import_from, new_module)
            continue
        # Split e.g. "..pkg.mod" into dots "..", package "pkg", sep ".", last "mod".
        body = current.lstrip(".")
        dots = current[: len(current) - len(body)]
        package, sep, last = body.rpartition(".")
        if last != old_tail:
            continue
        _replace_from_module(import_from, f"{dots}{package}{sep}{new_tail}")
def _iter_import_from(scope_node: NodeOrLeaf) -> Iterator[NodeOrLeaf]:
    """Yield every ``import_from`` node below ``scope_node``, depth-first in source order.

    ``import_from`` nodes are not searched further; leaves contribute nothing.
    """
    for node in getattr(scope_node, "children", None) or ():
        if node.type == "import_from":
            yield node
        else:
            yield from _iter_import_from(node)
def _replace_from_module(import_from: NodeOrLeaf, new_module: str) -> None:
    """Replace the dotted module portion of an ``import_from`` node in place.

    The module portion is every child strictly between the ``from`` and
    ``import`` keywords: any leading-dot operators plus the module name. It is
    swapped for the matching children of a freshly parsed
    ``from <new_module> import _`` snippet. The first replacement leaf takes
    the whitespace prefix of the first original module leaf, so the spacing
    after ``from`` is preserved.

    The function silently does nothing when:

    - either keyword cannot be located;
    - the module portion is empty;
    - the parsed snippet does not expose an ``import_from`` among its direct
      children.

    NOTE(review): the spliced nodes keep whatever positions they had in the
    snippet; nothing here adjusts ``start_pos``. Line numbers read from this
    tree afterwards may therefore be stale until it is re-parsed. Confirm that
    callers dump or re-parse before querying positions.
    """
    children = import_from.children
    # children: [keyword 'from', <module-parts>, keyword 'import', ...]
    # The module portion is between the 'from' keyword and the 'import' keyword.
    start_idx: int | None = None
    end_idx: int | None = None
    for i, c in enumerate(children):
        if c.type == "keyword" and c.value == "from":
            start_idx = i + 1
        elif c.type == "keyword" and c.value == "import":
            end_idx = i
            break
    if start_idx is None or end_idx is None or start_idx >= end_idx:
        return
    # Capture the prefix from the first module leaf so we can keep
    # whitespace stable.
    first = children[start_idx]
    first_leaf = first if isinstance(first, Leaf) else first.get_first_leaf()
    saved_prefix = first_leaf.prefix
    # Build replacement nodes by parsing a snippet `from <new> import _`
    # and pulling the module portion out of the inner ``import_from``.
    # Imported locally (as in the original) rather than at module level.
    from .source_tree import parse_snippet
    snippet = parse_snippet(f"from {new_module} import _\n")
    # parse_snippet wraps the statement in a ``simple_stmt`` whose first
    # child is the ``import_from`` we care about.
    # NOTE(review): only direct children are searched — if parse_snippet
    # ever returns a module-level wrapper, this finds nothing and the
    # function returns without rewriting.
    snippet_inner = None
    for c in getattr(snippet, "children", ()):
        if c.type == "import_from":
            snippet_inner = c
            break
    if snippet_inner is None:
        return
    # Locate the snippet's module portion with the same from/import scan.
    inner_children = snippet_inner.children
    sn_start: int | None = None
    sn_end: int | None = None
    for i, c in enumerate(inner_children):
        if c.type == "keyword" and c.value == "from":
            sn_start = i + 1
        elif c.type == "keyword" and c.value == "import":
            sn_end = i
            break
    if sn_start is None or sn_end is None:
        return
    new_module_nodes = inner_children[sn_start:sn_end]
    if not new_module_nodes:
        return
    # Restore the original whitespace prefix on the first new node.
    new_first = new_module_nodes[0]
    new_first_leaf = new_first if isinstance(new_first, Leaf) else new_first.get_first_leaf()
    new_first_leaf.prefix = saved_prefix
    # Splice: replace [start_idx:end_idx] with new_module_nodes.
    new_children = list(children[:start_idx]) + list(new_module_nodes) + list(children[end_idx:])
    # Re-parent the moved nodes so upward navigation (``.parent``) lands in
    # this tree rather than the discarded snippet.
    for n in new_module_nodes:
        n.parent = import_from
    # Slice-assign so the existing children list object is mutated in place.
    import_from.children[:] = new_children