# Source code for simvx.core.input.state

"""Input — input state tracker. Instance-based with module-level default."""

import contextvars
from contextlib import contextmanager

from ..math.types import Vec2
from .enums import JoyAxis, JoyButton, Key, MouseButton, MouseCaptureMode
from .events import InputBinding
from .map import InputMap as _default_map
from .map import _InputMap


class _Input:
    """Input state tracker. Create new instances for per-tree isolation.

    Tracks keyboard, mouse, gamepad, and touch state.
    Actions are registered via an InputMap; query with is_action_pressed() etc.
    Direct key/button queries use is_key_pressed(Key.X) / is_mouse_button_pressed(MouseButton.LEFT).

    State is fed in through the ``_on_*`` / ``_update_*`` hooks (or the public
    ``inject_*`` wrappers); "just pressed" / "just released" sets are edge
    markers that live until ``_new_frame()`` or ``_end_frame()`` clears them.
    """

    def __init__(self, input_map: _InputMap | None = None):
        """Create an empty tracker.

        Args:
            input_map: Action-to-binding map used by the action queries. When
                ``None`` the module-level default map is used.
        """
        self.input_map = input_map if input_map is not None else _default_map

        # --- Touch emulation (mouse <-> touch) ---
        self._emulate_touch_from_mouse: bool = False
        # NOTE(review): this flag is stored but never read inside this class;
        # presumably the platform adapter consumes it -- confirm.
        self._emulate_mouse_from_touch: bool = True

        # --- String-based state (written by platform adapters, read by UI system) ---
        self._keys: dict[str, bool] = {}
        self._keys_just_pressed: dict[str, bool] = {}
        self._keys_just_released: dict[str, bool] = {}
        self._mouse_pos: tuple[float, float] = (0.0, 0.0)
        self._mouse_delta: tuple[float, float] = (0.0, 0.0)
        self._scroll_delta: tuple[float, float] = (0.0, 0.0)
        self._gamepad_buttons: dict[int, dict[str, bool]] = {}
        self._gamepad_axes: dict[int, dict[str, float]] = {}

        # --- Typed state ---
        self._keys_pressed: set[int] = set()
        self._keys_just_pressed_typed: set[int] = set()
        self._keys_just_released_typed: set[int] = set()
        self._mouse_buttons_pressed: set[int] = set()
        self._mouse_buttons_just_pressed: set[int] = set()
        self._mouse_buttons_just_released: set[int] = set()
        self._joy_axes: dict[int, float] = {}
        self._joy_buttons_pressed: set[int] = set()
        self._joy_buttons_just_pressed: set[int] = set()
        self._joy_buttons_just_released: set[int] = set()
        self._capture_mode: MouseCaptureMode = MouseCaptureMode.VISIBLE
        self._capture_mode_callback: object = None  # Platform sets this to apply capture

        # --- Touch input ---
        # Values are (x, y, pressure) keyed by finger id.
        self._touches: dict[int, tuple[float, float, float]] = {}
        self._touches_just_pressed: dict[int, tuple[float, float, float]] = {}
        self._touches_just_released: set[int] = set()

    # ----------------------------------------------------------------
    # Action queries (via InputMap only)
    # ----------------------------------------------------------------

    def is_action_pressed(self, action: str) -> bool:
        """Check if any input mapped to the action is currently held."""
        return any(self._binding_pressed(b) for b in self.input_map.get_bindings(action))

    def is_action_just_pressed(self, action: str) -> bool:
        """Check if any input mapped to the action was pressed this frame.

        Axis bindings never report a just-pressed edge.
        """
        return any(self._binding_just_pressed(b) for b in self.input_map.get_bindings(action))

    def is_action_just_released(self, action: str) -> bool:
        """Check if any input mapped to the action was released this frame.

        Axis bindings never report a just-released edge.
        """
        return any(self._binding_just_released(b) for b in self.input_map.get_bindings(action))

    def get_action_strength(self, action: str) -> float:
        """Return action strength: 1.0 for digital press, analog value for axes.

        The result is the strongest of all bindings mapped to ``action`` and
        is 0.0 when the action has no bindings or none is active.
        """
        strength = 0.0
        for b in self.input_map.get_bindings(action):
            strength = max(strength, self._binding_strength(b))
        return strength

    # Alias
    get_strength = get_action_strength

    def get_axis(self, negative_action: str, positive_action: str) -> float:
        """Return axis value from two opposing actions. Range [-1, 1]."""
        return self.get_action_strength(positive_action) - self.get_action_strength(negative_action)

    def get_vector(self, neg_x: str, pos_x: str, neg_y: str, pos_y: str) -> Vec2:
        """Return a 2D direction vector built from four input actions.

        The vector is rescaled only when its length exceeds 1.0 (e.g. two
        digital keys held diagonally), so partial analog deflection is kept
        and the magnitude never exceeds 1.0.
        """
        x = self.get_action_strength(pos_x) - self.get_action_strength(neg_x)
        y = self.get_action_strength(pos_y) - self.get_action_strength(neg_y)
        v = Vec2(x, y)
        ln = v.length()
        return v / ln if ln > 1.0 else v

    # ----------------------------------------------------------------
    # Typed key query API
    # ----------------------------------------------------------------

    def is_key_pressed(self, key: Key) -> bool:
        """Check if a specific key is currently held down."""
        return int(key) in self._keys_pressed

    def is_key_just_pressed(self, key: Key) -> bool:
        """Check if a specific key was pressed this frame (not held from previous)."""
        return int(key) in self._keys_just_pressed_typed

    def is_key_just_released(self, key: Key) -> bool:
        """Check if a specific key was released this frame."""
        return int(key) in self._keys_just_released_typed

    # ----------------------------------------------------------------
    # Mouse
    # ----------------------------------------------------------------

    def is_mouse_button_pressed(self, button: MouseButton) -> bool:
        """Check if a mouse button is currently held."""
        return int(button) in self._mouse_buttons_pressed

    def is_mouse_button_just_pressed(self, button: MouseButton) -> bool:
        """Check if a mouse button was pressed this frame."""
        return int(button) in self._mouse_buttons_just_pressed

    def is_mouse_button_just_released(self, button: MouseButton) -> bool:
        """Check if a mouse button was released this frame."""
        return int(button) in self._mouse_buttons_just_released

    @property
    def mouse_position(self) -> Vec2:
        """Current mouse position in screen coordinates."""
        return Vec2(self._mouse_pos[0], self._mouse_pos[1])

    @property
    def mouse_delta(self) -> Vec2:
        """Mouse movement delta this frame."""
        return Vec2(self._mouse_delta[0], self._mouse_delta[1])

    @property
    def scroll_delta(self) -> tuple[float, float]:
        """Scroll wheel delta this frame (x, y)."""
        return self._scroll_delta

    def set_touch_emulation(self, enabled: bool = True):
        """Enable mouse-to-touch emulation (useful for testing touch on desktop).

        When enabled, left mouse button presses/releases and mouse moves also
        generate touch events with finger_id=0. This lets GestureRecognizer and
        other touch-consuming code work with a mouse.
        """
        self._emulate_touch_from_mouse = enabled

    def set_mouse_from_touch_emulation(self, enabled: bool = True):
        """Control whether the primary touch finger fires synthetic mouse events.

        Default is True: a primary-finger tap fires ``MouseButton.LEFT`` press
        and release events, so InputMap actions bound to the left mouse button
        work on touch devices without per-demo changes. Set to False when the
        application needs to distinguish touch from mouse input (raw touch is
        still delivered via ``touches`` / ``touches_just_pressed``).
        """
        self._emulate_mouse_from_touch = enabled

    def set_mouse_capture_mode(self, mode: MouseCaptureMode):
        """Set mouse cursor capture mode. Platform adapter applies the change.

        The mode is always recorded; the platform callback is invoked with
        ``mode`` only when one has been installed.
        """
        self._capture_mode = mode
        if self._capture_mode_callback:
            self._capture_mode_callback(mode)

    def get_mouse_capture_mode(self) -> MouseCaptureMode:
        """Get the current mouse capture mode."""
        return self._capture_mode

    # ----------------------------------------------------------------
    # Gamepad
    # ----------------------------------------------------------------

    def get_gamepad_axis(self, pad_id: int = 0, axis: str | JoyAxis = "left_x") -> float:
        """Get gamepad axis value [-1, 1].

        Accepts either a string name (legacy) or JoyAxis enum. Unknown pads
        or axes read as 0.0.
        """
        if isinstance(axis, JoyAxis):
            # NOTE(review): the typed store is not keyed by pad, so ``pad_id``
            # is ignored on this path.
            return self._joy_axes.get(int(axis), 0.0)
        return self._gamepad_axes.get(pad_id, {}).get(axis, 0.0)

    def is_gamepad_pressed(self, pad_id: int = 0, button: str | JoyButton = "a") -> bool:
        """Check if gamepad button is pressed.

        Accepts either a string name (legacy) or JoyButton enum. Unknown pads
        or buttons read as not pressed.
        """
        if isinstance(button, JoyButton):
            # NOTE(review): the typed store is not keyed by pad, so ``pad_id``
            # is ignored on this path.
            return int(button) in self._joy_buttons_pressed
        return self._gamepad_buttons.get(pad_id, {}).get(button, False)

    def get_gamepad_vector(self, pad_id: int = 0, stick: str = "left", deadzone: float = 0.15) -> Vec2:
        """Get stick as Vec2 with a radial deadzone applied.

        Args:
            pad_id: Gamepad index in the legacy string-keyed state.
            stick: Stick name prefix; the ``"<stick>_x"`` and ``"<stick>_y"``
                legacy axes are read.
            deadzone: Vector length below which the stick reads as (0, 0).

        Returns:
            The raw stick vector, or a zero vector inside the deadzone.
        """
        x = self.get_gamepad_axis(pad_id, f"{stick}_x")
        y = self.get_gamepad_axis(pad_id, f"{stick}_y")
        v = Vec2(x, y)
        if v.length() < deadzone:
            return Vec2(0, 0)
        return v

    # ----------------------------------------------------------------
    # Public injection (for testing / virtual controls)
    # ----------------------------------------------------------------

    def inject_key(self, key: int | Key, pressed: bool) -> None:
        """Inject a synthetic key event. Same path as platform adapters."""
        self._on_key(int(key), pressed)

    def inject_mouse_button(self, button: int | MouseButton, pressed: bool) -> None:
        """Inject a synthetic mouse button event. Same path as platform adapters."""
        self._on_mouse_button(int(button), pressed)

    # ----------------------------------------------------------------
    # Internal: called by platform adapters
    # ----------------------------------------------------------------

    def _on_key(self, key: int, pressed: bool):
        """Called by platform adapter for typed key events."""
        if pressed:
            # Only a transition from "up" counts as just-pressed, so repeated
            # press events for a held key do not re-trigger the edge.
            if key not in self._keys_pressed:
                self._keys_just_pressed_typed.add(key)
            self._keys_pressed.add(key)
        else:
            self._keys_pressed.discard(key)
            self._keys_just_released_typed.add(key)

    def _on_mouse_button(self, button: int, pressed: bool):
        """Called by platform adapter for typed mouse button events."""
        if pressed:
            if button not in self._mouse_buttons_pressed:
                self._mouse_buttons_just_pressed.add(button)
            self._mouse_buttons_pressed.add(button)
        else:
            self._mouse_buttons_pressed.discard(button)
            self._mouse_buttons_just_released.add(button)
        # Mouse->touch emulation: left button (0) maps to finger 0
        if self._emulate_touch_from_mouse and button == 0:
            x, y = self._mouse_pos
            self._update_touch(0, 0 if pressed else 1, x, y, 1.0 if pressed else 0.0)

    def _on_mouse_move(self, x: float, y: float):
        """Called by platform adapter for mouse movement."""
        old = self._mouse_pos
        self._mouse_pos = (x, y)
        # NOTE(review): the delta is replaced per event, not accumulated, so
        # with several move events in one frame only the last step is visible.
        self._mouse_delta = (x - old[0], y - old[1])
        # Mouse->touch emulation: emit move only when finger 0 is "down" (left button held)
        if self._emulate_touch_from_mouse and 0 in self._touches:
            self._update_touch(0, 2, x, y, 1.0)

    def _on_joy_button(self, button: int, pressed: bool):
        """Called by platform adapter for gamepad button events."""
        if pressed:
            if button not in self._joy_buttons_pressed:
                self._joy_buttons_just_pressed.add(button)
            self._joy_buttons_pressed.add(button)
        else:
            self._joy_buttons_pressed.discard(button)
            self._joy_buttons_just_released.add(button)

    def _on_joy_axis(self, axis: int, value: float):
        """Called by platform adapter for gamepad axis changes."""
        self._joy_axes[axis] = value

    def _update_gamepad(self, pad_id: int, buttons: dict[str, bool], axes: dict[str, float]):
        """Called by platform adapter to update gamepad state (legacy string API).

        The dicts are stored as passed (not copied).
        """
        self._gamepad_buttons[pad_id] = buttons
        self._gamepad_axes[pad_id] = axes

    # ----------------------------------------------------------------
    # Touch input
    # ----------------------------------------------------------------

    def _update_touch(self, finger_id: int, action: int, x: float, y: float, pressure: float):
        """Called by platform adapter for touch events. action: 0=down, 1=up, 2=move.

        Any other ``action`` value is ignored.
        """
        if action == 0:  # down
            self._touches[finger_id] = (x, y, pressure)
            self._touches_just_pressed[finger_id] = (x, y, pressure)
        elif action == 1:  # up
            self._touches.pop(finger_id, None)
            self._touches_just_released.add(finger_id)
        elif action == 2:  # move
            self._touches[finger_id] = (x, y, pressure)

    @property
    def touches(self) -> dict[int, tuple[float, float, float]]:
        """Active touches: ``{finger_id: (x, y, pressure)}`` (a copy)."""
        return dict(self._touches)

    @property
    def touches_just_pressed(self) -> dict[int, tuple[float, float, float]]:
        """Touches that started this frame (a copy)."""
        return dict(self._touches_just_pressed)

    @property
    def touches_just_released(self) -> set[int]:
        """Finger IDs that were lifted this frame (a copy)."""
        return set(self._touches_just_released)

    @property
    def touch_positions(self) -> list[tuple[int, float, float, float]]:
        """Active touches as a list of ``(finger_id, x, y, pressure)``."""
        return [(fid, x, y, p) for fid, (x, y, p) in self._touches.items()]

    def is_touch_pressed(self, finger_id: int = 0) -> bool:
        """Whether finger_id is currently touching."""
        return finger_id in self._touches

    @property
    def touch_count(self) -> int:
        """Number of active touch points."""
        return len(self._touches)

    def _new_frame(self):
        """Called by engine at frame start. Clears the typed just-pressed/released sets.

        Unlike ``_end_frame()``, this leaves the string-keyed edge dicts, the
        mouse/scroll deltas and the touch edge markers untouched.
        """
        self._keys_just_pressed_typed.clear()
        self._keys_just_released_typed.clear()
        self._mouse_buttons_just_pressed.clear()
        self._mouse_buttons_just_released.clear()
        self._joy_buttons_just_pressed.clear()
        self._joy_buttons_just_released.clear()

    def _end_frame(self):
        """Called by engine at frame end. Clears all per-frame state.

        Held keys/buttons, axis values, mouse position and active touches are
        preserved; only edge markers and deltas are reset.
        """
        self._keys_just_pressed.clear()
        self._keys_just_released.clear()
        self._keys_just_pressed_typed.clear()
        self._keys_just_released_typed.clear()
        self._mouse_buttons_just_pressed.clear()
        self._mouse_buttons_just_released.clear()
        self._joy_buttons_just_pressed.clear()
        self._joy_buttons_just_released.clear()
        self._mouse_delta = (0.0, 0.0)
        self._scroll_delta = (0.0, 0.0)
        self._touches_just_pressed.clear()
        self._touches_just_released.clear()

    def _reset(self):
        """Reset all input state. Useful for testing.

        Restores every tracked value and both emulation flags to their
        constructor defaults and clears ``input_map``. The platform capture
        callback is left installed.
        """
        self._keys.clear()
        self._keys_just_pressed.clear()
        self._keys_just_released.clear()
        self._mouse_pos = (0.0, 0.0)
        self._mouse_delta = (0.0, 0.0)
        self._scroll_delta = (0.0, 0.0)
        self._gamepad_buttons.clear()
        self._gamepad_axes.clear()
        self._keys_pressed.clear()
        self._keys_just_pressed_typed.clear()
        self._keys_just_released_typed.clear()
        self._mouse_buttons_pressed.clear()
        self._mouse_buttons_just_pressed.clear()
        self._mouse_buttons_just_released.clear()
        self._joy_axes.clear()
        self._joy_buttons_pressed.clear()
        self._joy_buttons_just_pressed.clear()
        self._joy_buttons_just_released.clear()
        self._capture_mode = MouseCaptureMode.VISIBLE
        self._emulate_touch_from_mouse = False
        # Restore the sibling emulation flag too, so a reset tracker matches a
        # freshly constructed one.
        self._emulate_mouse_from_touch = True
        self._touches.clear()
        self._touches_just_pressed.clear()
        self._touches_just_released.clear()
        self.input_map.clear()

    # ----------------------------------------------------------------
    # Internal: binding resolution helpers
    # ----------------------------------------------------------------

    def _binding_pressed(self, b: InputBinding) -> bool:
        """Check if a typed binding is currently pressed.

        The first populated field wins, in the order key, mouse button,
        joy button, joy axis. An axis counts as pressed once it passes the
        binding's deadzone in the bound direction.
        """
        if b.key is not None:
            return int(b.key) in self._keys_pressed
        if b.mouse_button is not None:
            return int(b.mouse_button) in self._mouse_buttons_pressed
        if b.joy_button is not None:
            return int(b.joy_button) in self._joy_buttons_pressed
        if b.joy_axis is not None:
            val = self._joy_axes.get(int(b.joy_axis), 0.0)
            if b.joy_axis_positive:
                return val > b.deadzone
            return val < -b.deadzone
        return False

    def _binding_just_pressed(self, b: InputBinding) -> bool:
        """Check if a typed binding was just pressed this frame."""
        if b.key is not None:
            return int(b.key) in self._keys_just_pressed_typed
        if b.mouse_button is not None:
            return int(b.mouse_button) in self._mouse_buttons_just_pressed
        if b.joy_button is not None:
            return int(b.joy_button) in self._joy_buttons_just_pressed
        # Axis just-pressed would need previous-frame state; not supported for axes
        return False

    def _binding_just_released(self, b: InputBinding) -> bool:
        """Check if a typed binding was just released this frame."""
        if b.key is not None:
            return int(b.key) in self._keys_just_released_typed
        if b.mouse_button is not None:
            return int(b.mouse_button) in self._mouse_buttons_just_released
        if b.joy_button is not None:
            return int(b.joy_button) in self._joy_buttons_just_released
        # Axis just-released is not supported, for the same reason as just-pressed.
        return False

    def _binding_strength(self, b: InputBinding) -> float:
        """Return analog strength [0, 1] for a binding.

        Digital bindings yield exactly 0.0 or 1.0. Axis bindings yield the
        deflection in the bound direction once it exceeds the deadzone,
        clamped so an out-of-range axis reading cannot exceed 1.0.
        """
        if b.key is not None:
            return 1.0 if int(b.key) in self._keys_pressed else 0.0
        if b.mouse_button is not None:
            return 1.0 if int(b.mouse_button) in self._mouse_buttons_pressed else 0.0
        if b.joy_button is not None:
            return 1.0 if int(b.joy_button) in self._joy_buttons_pressed else 0.0
        if b.joy_axis is not None:
            val = self._joy_axes.get(int(b.joy_axis), 0.0)
            # Fold both directions into one "deflection toward the bound side".
            magnitude = val if b.joy_axis_positive else -val
            if magnitude > b.deadzone:
                return min(1.0, max(0.0, magnitude))
            return 0.0
        return 0.0

# Fallback tracker used whenever no other instance has been activated for the
# current context.
_default_input = _Input()

# Context-local pointer to the tracker that the module-level proxy resolves
# to; reads fall back to `_default_input` when nothing has been set.
_active_input: contextvars.ContextVar[_Input] = contextvars.ContextVar(
    "_active_input",
    default=_default_input,
)

class _InputProxy:
    """Stateless stand-in that forwards everything to the context's active _Input.

    Code written as ``from simvx.core import Input; Input.is_action_pressed(...)``
    keeps working unchanged: each attribute read, write or delete is resolved
    against whichever _Input instance ``_active_input`` holds for the current
    context at that moment (the module-level default when none was set).
    """

    # No per-proxy storage: every attribute lives on the active _Input.
    __slots__ = ()

    def __getattr__(self, name: str):
        # Only reached for names not found on the proxy class itself.
        target = _active_input.get()
        return getattr(target, name)

    def __setattr__(self, name: str, value):
        target = _active_input.get()
        setattr(target, name, value)

    def __delattr__(self, name: str):
        target = _active_input.get()
        delattr(target, name)

    def __repr__(self) -> str:
        target = _active_input.get()
        return repr(target)

@contextmanager
def set_active_input(instance: _Input):
    """Context manager to set the active Input for the current context.

    While the ``with`` body runs, ``_active_input`` resolves to ``instance``.
    The previous value is restored on exit, including when the body raises.

    Args:
        instance: The _Input to make active inside the ``with`` block.
    """
    token = _active_input.set(instance)
    try:
        yield
    finally:
        # Reset via the token so nested activations unwind to the value that
        # was active before this one.
        _active_input.reset(token)
# Module-level input handle: a proxy instance that forwards every attribute
# access to the _Input held by `_active_input` for the current context.
Input = _InputProxy()