
base_component ¤

Base class and decorators for defining JAX-compatible circuit components.

Circuit components are defined as plain Python functions decorated with component or source, which compile them into CircuitComponent subclasses: Equinox modules whose parameters are JAX-traceable leaves. The resulting classes expose two entry points:

  • __call__ — a debug-friendly instance method that accepts port voltages and state values as keyword arguments and returns the physics dicts directly.
  • solver_call — a class method used by the transient solver that operates on flat JAX arrays and a parameter container, and is compatible with jax.vmap and jax.jacfwd.

Example::

@component(ports=("p1", "p2"))
def Resistor(signals: Signals, s: States, R: float = 1.0):
    i = (signals.p1 - signals.p2) / R
    return {"p1": i, "p2": -i}, {}

r = Resistor(R=100.0)
f, q = r(p1=1.0, p2=0.0)

Classes:

Name Description
CircuitComponent

Base class for all JAX-compatible circuit components.

Signals

Protocol representing the port voltage signals passed to a component's physics function.

States

Protocol representing the internal state variables passed to a component's physics function.

Functions:

Name Description
component

Decorator for defining a time-independent circuit component.

source

Decorator for defining a time-dependent circuit component.

CircuitComponent ¤

Bases: Module

Base class for all JAX-compatible circuit components.

Subclasses are not written by hand. They are generated at import time by the component and source decorators, which inspect the decorated function's signature to populate the class variables and wire up the two physics entry points.

Class Variables

  • ports: Ordered tuple of port names, e.g. ("p1", "p2").
  • states: Ordered tuple of internal state variable names, e.g. ("i_L",). Empty for purely algebraic components.
  • _uses_time: True for components decorated with source, whose physics function accepts a t argument.
  • _VarsType_P: Namedtuple type for unpacking port voltages from a flat array. None if the component has no ports.
  • _VarsType_S: Namedtuple type for unpacking state variables from a flat array. None if the component has no states.
  • _n_ports: Number of ports, cached to avoid repeated len calls in the hot path.
  • _fast_physics: Static closure over the user-defined physics function, compatible with jax.vmap and jax.jacfwd. Signature is (vars_vec, params, t) -> (f_vec, q_vec).

Methods:

Name Description
physics

Raw physics dispatch; overridden by the decorator-generated subclass.

solver_call

Evaluate the component physics (solver entry point).

physics ¤

physics(*args: Any, **kwargs: Any) -> tuple[dict, dict]

Raw physics dispatch; overridden by the decorator-generated subclass.

Source code in circulax/components/base_component.py
def physics(self, *args: Any, **kwargs: Any) -> tuple[dict, dict]:
    """Raw physics dispatch; overridden by the decorator-generated subclass."""
    raise NotImplementedError

solver_call classmethod ¤

solver_call(t: float, y: Array, args: Any) -> tuple[Array, Array]

Evaluate the component physics (solver entry point).

Thin wrapper around the static _fast_physics closure. Called by the transient solver inside jax.vmap across all instances in a component group, and differentiated via jax.jacfwd to assemble the system Jacobian.

Parameters:

Name Type Description Default
t float

Current simulation time.

required
y Array

Flat state vector of shape (n_ports + n_states,) containing port voltages followed by state variable values.

required
args Any

Parameter container for this instance. May be a dict {"R": 100.0} or an object (e.g. the component instance itself) whose attributes match the parameter names. Must not be a raw scalar.

required

Returns:

Type Description
tuple[Array, Array]

A two-tuple (f_vec, q_vec) of JAX arrays, each of shape (n_ports + n_states,), containing the resistive and reactive contributions for every port and state variable.

Source code in circulax/components/base_component.py
@classmethod
def solver_call(
    cls,
    t: float,
    y: jax.Array,
    args: Any,
) -> tuple[jax.Array, jax.Array]:
    """Evaluate the component physics (solver entry point).

    Thin wrapper around the static ``_fast_physics`` closure. Called by
    the transient solver inside ``jax.vmap`` across all instances in a
    component group, and differentiated via ``jax.jacfwd`` to assemble
    the system Jacobian.

    Args:
        t: Current simulation time.
        y: Flat state vector of shape ``(n_ports + n_states,)`` containing
            port voltages followed by state variable values.
        args: Parameter container for this instance. May be a dict
            ``{"R": 100.0}`` or an object (e.g. the component instance
            itself) whose attributes match the parameter names. Must not
            be a raw scalar.

    Returns:
        A two-tuple ``(f_vec, q_vec)`` of JAX arrays, each of shape
        ``(n_ports + n_states,)``, containing the resistive and reactive
        contributions for every port and state variable.

    """
    return cls._fast_physics(y, args, t)
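
The flat-vector layout described above can be illustrated without JAX. The following is a minimal pure-Python sketch, not the library's implementation: make_fast_physics and _lookup are hypothetical names, plain lists stand in for JAX arrays, and treating a missing dict key as a zero contribution is an assumption. It shows one plausible way a closure could unpack y into ports and states, resolve parameters from either a dict or an object, and pack the returned dicts into (f_vec, q_vec).

```python
from collections import namedtuple
from types import SimpleNamespace


def _lookup(args, name):
    """Read a parameter from a dict or from an object's attributes."""
    return args[name] if isinstance(args, dict) else getattr(args, name)


def make_fast_physics(fn, ports, states, param_names):
    """Hypothetical stand-in for the decorator-generated closure."""
    order = ports + states  # port entries first, then state entries
    Ports = namedtuple("Ports", ports) if ports else None
    States = namedtuple("States", states) if states else None
    n_ports = len(ports)

    def fast_physics(vars_vec, params, t):
        # Split the flat vector into named port voltages and state values.
        signals = Ports(*vars_vec[:n_ports]) if Ports else ()
        s = States(*vars_vec[n_ports:]) if States else ()
        kwargs = {name: _lookup(params, name) for name in param_names}
        # t is unused here because this sketch covers the time-independent case.
        f, q = fn(signals, s, **kwargs)
        # Assumption: keys absent from a dict contribute zero.
        f_vec = [f.get(k, 0.0) for k in order]
        q_vec = [q.get(k, 0.0) for k in order]
        return f_vec, q_vec

    return fast_physics


def resistor(signals, s, R=1.0):
    i = (signals.p1 - signals.p2) / R
    return {"p1": i, "p2": -i}, {}


fast = make_fast_physics(resistor, ("p1", "p2"), (), ("R",))
f_dict, q_dict = fast([1.0, 0.0], {"R": 100.0}, 0.0)
f_obj, q_obj = fast([1.0, 0.0], SimpleNamespace(R=100.0), 0.0)
```

Both calls produce the same vectors, which mirrors the statement that args may be a dict or an object whose attributes match the parameter names.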

Signals ¤

Bases: Protocol

Protocol representing the port voltage signals passed to a component's physics function.

Attributes are accessed by port name (e.g. signals.p1), backed by a namedtuple constructed from the component's ports declaration.

States ¤

Bases: Protocol

Protocol representing the internal state variables passed to a component's physics function.

Attributes are accessed by state name (e.g. s.i_L), backed by a namedtuple constructed from the component's states declaration.

component ¤

component(ports: tuple[str, ...] = (), states: tuple[str, ...] = ()) -> Any

Decorator for defining a time-independent circuit component.

Compiles the decorated physics function into a CircuitComponent subclass. The function must begin with (signals, s) followed by any number of parameters with defaults, which become JAX-traceable Equinox fields on the resulting class.

Parameters:

Name Type Description Default
ports tuple[str, ...]

Ordered tuple of port names. Must match the connection keys used in the netlist.

()
states tuple[str, ...]

Ordered tuple of internal state variable names. State variables are appended to the solver's state vector after the node voltages.

()

Returns:

Type Description
Any

A decorator that accepts a physics function and returns a CircuitComponent subclass.

Example::

@component(ports=("p1", "p2"))
def Resistor(signals: Signals, s: States, R: float = 1.0):
    i = (signals.p1 - signals.p2) / R
    return {"p1": i, "p2": -i}, {}
Source code in circulax/components/base_component.py
def component(
    ports: tuple[str, ...] = (),
    states: tuple[str, ...] = (),
) -> Any:
    """Decorator for defining a time-independent circuit component.

    Compiles the decorated physics function into a :class:`CircuitComponent`
    subclass. The function must begin with ``(signals, s)`` followed by any
    number of parameters with defaults, which become JAX-traceable Equinox
    fields on the resulting class.

    Args:
        ports: Ordered tuple of port names. Must match the connection keys
            used in the netlist.
        states: Ordered tuple of internal state variable names. State
            variables are appended to the solver's state vector after the
            node voltages.

    Returns:
        A decorator that accepts a physics function and returns a
        :class:`CircuitComponent` subclass.

    Example::

        @component(ports=("p1", "p2"))
        def Resistor(signals: Signals, s: States, R: float = 1.0):
            i = (signals.p1 - signals.p2) / R
            return {"p1": i, "p2": -i}, {}

    """
    return lambda fn: _build_component(fn, ports, states, uses_time=False)
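
The signature convention above (leading signals and s, then parameters with defaults) can be checked with the standard inspect module. This is only a sketch of how such an inspection might look; split_signature is a hypothetical helper, and raising TypeError for a parameter without a default is an assumption rather than documented behaviour of _build_component.

```python
import inspect


def split_signature(fn, uses_time=False):
    """Return (leading argument names, {param: default}) for a physics function."""
    # Two leading arguments (signals, s), or three when t is also expected.
    n_lead = 3 if uses_time else 2
    params = list(inspect.signature(fn).parameters.values())
    leading, rest = params[:n_lead], params[n_lead:]
    # Assumption: every remaining parameter must carry a default value.
    missing = [p.name for p in rest if p.default is inspect.Parameter.empty]
    if missing:
        raise TypeError(f"parameters without defaults: {missing}")
    return [p.name for p in leading], {p.name: p.default for p in rest}


def resistor(signals, s, R=1.0):
    i = (signals.p1 - signals.p2) / R
    return {"p1": i, "p2": -i}, {}


leading, defaults = split_signature(resistor)
```

Here defaults holds the names and default values that would become fields on the generated class.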

source ¤

source(ports: tuple[str, ...] = (), states: tuple[str, ...] = ()) -> Any

Decorator for defining a time-dependent circuit component.

Identical to component except the decorated physics function must accept t as its third argument (after signals and s), and may use it to implement time-varying behaviour such as sinusoidal sources or delayed step functions.

Parameters:

Name Type Description Default
ports tuple[str, ...]

Ordered tuple of port names.

()
states tuple[str, ...]

Ordered tuple of internal state variable names.

()

Returns:

Type Description
Any

A decorator that accepts a physics function and returns a CircuitComponent subclass.

Example::

@source(ports=("p1", "p2"), states=("i_src",))
def VoltageSource(signals: Signals, s: States, t: float, V: float = 1.0):
    constraint = (signals.p1 - signals.p2) - V
    return {"p1": s.i_src, "p2": -s.i_src, "i_src": constraint}, {}
Source code in circulax/components/base_component.py
def source(
    ports: tuple[str, ...] = (),
    states: tuple[str, ...] = (),
) -> Any:
    """Decorator for defining a time-dependent circuit component.

    Identical to :func:`component` except the decorated physics function
    must accept ``t`` as its third argument (after ``signals`` and ``s``),
    and may use it to implement time-varying behaviour such as sinusoidal
    sources or delayed step functions.

    Args:
        ports: Ordered tuple of port names.
        states: Ordered tuple of internal state variable names.

    Returns:
        A decorator that accepts a physics function and returns a
        :class:`CircuitComponent` subclass.

    Example::

        @source(ports=("p1", "p2"), states=("i_src",))
        def VoltageSource(signals: Signals, s: States, t: float, V: float = 1.0):
            constraint = (signals.p1 - signals.p2) - V
            return {"p1": s.i_src, "p2": -s.i_src, "i_src": constraint}, {}

    """
    return lambda fn: _build_component(fn, ports, states, uses_time=True)
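
The VoltageSource example accepts t but does not use it. As an illustration of time-varying behaviour, the sketch below writes a sinusoidal physics body in the shape the source decorator expects. It is deliberately left undecorated and uses math.sin with hand-built namedtuples so that it runs with the standard library alone; the freq parameter and the function name are illustrative, and a version traced by JAX would use jax.numpy.sin instead.

```python
import math
from collections import namedtuple

# Hand-built stand-ins for the Signals and States objects.
Ports = namedtuple("Ports", ("p1", "p2"))
States = namedtuple("States", ("i_src",))


def sine_voltage_source(signals, s, t, V=1.0, freq=50.0):
    """Physics body for a sinusoidal source; t is the third argument."""
    v_target = V * math.sin(2.0 * math.pi * freq * t)
    # The extra state i_src enforces v(p1) - v(p2) == v_target.
    constraint = (signals.p1 - signals.p2) - v_target
    return {"p1": s.i_src, "p2": -s.i_src, "i_src": constraint}, {}


# At t = 0 the target voltage is zero, so equal port voltages satisfy the constraint.
f, q = sine_voltage_source(Ports(0.0, 0.0), States(0.1), t=0.0)
```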