
solvers

Root finding and transient solvers.

Modules:

  • assembly: Assembly functions for the transient circuit solver.
  • linear: Circuit linear solvers (strategy pattern).
  • transient: Transient solvers to be used with Diffrax.

Classes:

  • CircuitLinearSolver: Abstract base class for all circuit linear solvers.
  • DenseSolver: Solves the system using dense matrix factorization (LU).
  • KLUSolver: Solves the system using the KLU sparse solver (via klujax).
  • SparseSolver: Solves the system using JAX's iterative BiCGStab solver.
  • VectorizedTransientSolver: Transient solver that works strictly on flat (real) vectors.

Functions:

  • analyze_circuit: Initializes a linear solver strategy for circuit analysis.
  • assemble_system_complex: Assembles the residual vectors and effective Jacobian values for an unrolled complex system.
  • assemble_system_real: Assembles the residual vectors and effective Jacobian values for a real system.
  • setup_transient: Configures and returns a function for executing transient analysis.

CircuitLinearSolver

Bases: AbstractLinearSolver

Abstract base class for all circuit linear solvers.

This class provides the unified interface for:

  1. Storing the static matrix structure (indices, rows, cols).
  2. Handling real vs. complex-unrolled system configurations.
  3. Providing a robust Newton-Raphson DC operating-point solver.

Attributes:

  • ground_indices (Array): Indices of nodes connected to ground (forced to 0 V).
  • is_complex (bool): Static flag. If True, the system is 2N x 2N (real and imaginary parts unrolled); if False, it is N x N (real).

Methods:

  • assume_full_rank: Indicates whether the solver assumes the operator is full rank.
  • compute: Not used internally; raises NotImplementedError (call _solve_impl instead).
  • init: Initializes the solver state (no-op for stateless solvers).
  • solve_dc: Performs a robust DC operating-point analysis (Newton-Raphson).

assume_full_rank

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute

compute(state: Any, vector: Array, options: Any) -> Solution

Present only to satisfy the lineax AbstractLinearSolver API; calling it raises NotImplementedError. Internally, the circuit solvers call _solve_impl directly to avoid overhead.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

init

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

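solve_dc

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC operating-point analysis (damped Newton-Raphson).

This method:

  1. Selects the real or complex-unrolled assembly based on self.is_complex.
  2. Assembles the system with a very large timestep (DC_DT), so the reactive dq/dt terms vanish and capacitors behave as open circuits.
  3. Applies the ground constraints by adding a large penalty (GROUND_STIFFNESS * V) to the residual at each ground node, which drives those voltages to 0.
  4. Solves the linear system J * delta = -residual.
  5. Scales the step so that its largest component does not exceed DAMPING_FACTOR (voltage limiting), which keeps exponential device models (diodes, transistors) from blowing up.

The iteration runs as an Optimistix fixed-point loop (rtol = atol = 1e-6, at most 100 steps). Because throw=False, the last iterate is returned even if the loop did not converge.

Parameters:

  • component_groups (dict, required): The circuit components and their parameters.
  • y_guess (Array, required): Initial guess vector, shape [N] or [2N].

Returns:

  • Array: The converged solution vector (flat).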

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
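The damped Newton loop in solve_dc can be reproduced on a toy problem. The NumPy sketch below mirrors its steps on a two-node circuit (ground plus one node fed by a 1 mA source, with a resistor and a diode to ground). The constant values, the device model, the function name damped_newton_dc, and the assumption that the matching GROUND_STIFFNESS term is added to the Jacobian diagonal (only the residual side is visible in the code above) are all illustrative, not circulax's actual settings.

```python
import numpy as np

# Hypothetical stand-ins for the module constants used by solve_dc.
GROUND_STIFFNESS = 1e9
DAMPING_FACTOR = 0.1   # largest allowed update per step, in volts
DAMPING_EPS = 1e-12

# Toy circuit: 1 mA source into node 1, 1 kOhm resistor and diode from node 1
# to node 0. Node 0 is the ground node. Unknowns: y = [v0, v1].
I_SRC, R, I_SAT, V_T = 1e-3, 1e3, 1e-14, 0.025852


def residual_and_jacobian(y):
    vd = y[1] - y[0]
    i = vd / R + I_SAT * (np.exp(vd / V_T) - 1.0) - I_SRC  # net current leaving node 1
    g = 1.0 / R + (I_SAT / V_T) * np.exp(vd / V_T)         # di/dvd
    f = np.array([-i, i])                                   # KCL residuals, nodes 0 and 1
    jac = np.array([[g, -g], [-g, g]])                      # singular until grounded
    return f, jac


def damped_newton_dc(y_guess, ground_indices=(0,), max_steps=100, tol=1e-12):
    y = np.asarray(y_guess, dtype=float).copy()
    for _ in range(max_steps):
        f, jac = residual_and_jacobian(y)
        # Ground penalty on the residual (as in solve_dc) and, by assumption,
        # the matching diagonal term on the Jacobian.
        for idx in ground_indices:
            f[idx] += GROUND_STIFFNESS * y[idx]
            jac[idx, idx] += GROUND_STIFFNESS
        delta = np.linalg.solve(jac, -f)  # J * delta = -R
        # Voltage limiting: the largest component of the step is capped.
        damping = min(1.0, DAMPING_FACTOR / (np.max(np.abs(delta)) + DAMPING_EPS))
        step = damping * delta
        y = y + step
        if np.max(np.abs(step)) < tol:
            break
    return y


y = damped_newton_dc([0.0, 0.0])
```

Without damping, the first Newton step would jump to the resistor-only solution of 1 V, where exp(v/V_T) is around 1e16; with DAMPING_FACTOR = 0.1 the iterate instead climbs in 0.1 V increments until it reaches the diode knee, and Newton then converges quadratically.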

DenseSolver

Bases: CircuitLinearSolver

Solves the system using dense matrix factorization (LU).

Best For
  • Small to medium circuits (N < 2000).
  • Wavelength sweeps (AC analysis) on GPU.
  • Systems where vmap parallelism is critical.

Attributes:

  • static_rows (Array): Row indices for placing values into the dense matrix.
  • static_cols (Array): Column indices for placing values into the dense matrix.
  • g_leak (float): Leakage conductance added to the diagonal to prevent a singular matrix.

Methods:

  • assume_full_rank: Indicates whether the solver assumes the operator is full rank.
  • compute: Not used internally; raises NotImplementedError (call _solve_impl instead).
  • from_component_groups: Factory method that pre-calculates the indices for the dense matrix.
  • init: Initializes the solver state (no-op for stateless solvers).
  • solve_dc: Performs a robust DC operating-point analysis (Newton-Raphson).

assume_full_rank, compute

Same as CircuitLinearSolver.assume_full_rank and CircuitLinearSolver.compute (identical source in circulax/solvers/linear.py).

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> DenseSolver

Factory method to pre-calculate indices for the dense matrix.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "DenseSolver":
    """Factory method to pre-calculate indices for the dense matrix."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
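DenseSolver pre-computes static_rows and static_cols so that each Newton step only has to scatter the assembled Jacobian values into a dense matrix and factorize it. A minimal NumPy sketch of that scatter-and-solve step follows; the helper name dense_solve_from_coo and the way g_leak is applied are assumptions for illustration. In JAX the scatter would typically be written as jnp.zeros((n, n)).at[rows, cols].add(vals).

```python
import numpy as np


def dense_solve_from_coo(rows, cols, vals, rhs, sys_size, g_leak=1e-12):
    """Hypothetical helper: scatter COO values into a dense matrix and LU-solve."""
    A = np.zeros((sys_size, sys_size))
    # np.add.at sums repeated (row, col) pairs, as overlapping component
    # stamps on the same matrix entry must accumulate.
    np.add.at(A, (np.asarray(rows), np.asarray(cols)), np.asarray(vals, dtype=float))
    # Small leakage conductance on the diagonal guards against singularity.
    A[np.diag_indices(sys_size)] += g_leak
    # np.linalg.solve factorizes A with LU (LAPACK gesv) and back-substitutes.
    return np.linalg.solve(A, rhs)


# Two-node resistor network: 1 S from node 0 to ground, between nodes 0 and 1,
# and from node 1 to ground; 1 A injected into node 0.
rows = [0, 0, 0, 1, 1, 1]
cols = [0, 0, 1, 0, 1, 1]
vals = [1.0, 1.0, -1.0, -1.0, 1.0, 1.0]
x = dense_solve_from_coo(rows, cols, vals, np.array([1.0, 0.0]), sys_size=2)
# x is approximately [2/3, 1/3]
```

Because the index arrays are fixed at construction time, only vals changes between Newton iterations, which is what makes this path easy to vmap over parameter sweeps.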

init, solve_dc

Same as CircuitLinearSolver.init and CircuitLinearSolver.solve_dc (identical source in circulax/solvers/linear.py).

KLUSolver

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC operating points of massive meshes.
  • Cases where DenseSolver runs out of memory (OOM).

Note

Does NOT support vmap (batching) automatically.

Methods:

  • assume_full_rank: Indicates whether the solver assumes the operator is full rank.
  • compute: Not used internally; raises NotImplementedError (call _solve_impl instead).
  • from_component_groups: Factory method that pre-hashes the indices for sparse coalescence (merging duplicate entries).
  • init: Initializes the solver state (no-op for stateless solvers).
  • solve_dc: Performs a robust DC operating-point analysis (Newton-Raphson).

assume_full_rank, compute

Same as CircuitLinearSolver.assume_full_rank and CircuitLinearSolver.compute (identical source in circulax/solvers/linear.py).

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> KLUSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
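Component stamps frequently hit the same (row, col) entry, while a sparse factorization such as KLU expects each entry once. The factory therefore pre-computes unique coordinates plus a map from every raw entry to its unique slot. A NumPy sketch of that idea follows; it is not the actual _klu_deduplicate (whose handling of ground_idxs is not shown here), and the function names coalesce_coo and coalesce_values are hypothetical.

```python
import numpy as np


def coalesce_coo(rows, cols, sys_size):
    """Hypothetical: unique (row, col) pairs and a raw-entry -> unique-slot map."""
    keys = np.asarray(rows) * sys_size + np.asarray(cols)  # linear index per entry
    unique_keys, map_idx = np.unique(keys, return_inverse=True)
    u_rows, u_cols = np.divmod(unique_keys, sys_size)
    return u_rows, u_cols, map_idx.ravel(), unique_keys.size


def coalesce_values(vals, map_idx, n_unique):
    """Sum each raw value into its unique slot (done once per Newton step)."""
    return np.bincount(map_idx, weights=vals, minlength=n_unique)


rows = [0, 0, 0, 1, 1, 1]
cols = [0, 0, 1, 0, 1, 1]
u_rows, u_cols, map_idx, n_unique = coalesce_coo(rows, cols, sys_size=2)
u_vals = coalesce_values(np.array([1.0, 1.0, -1.0, -1.0, 1.0, 1.0]), map_idx, n_unique)
# u_rows = [0, 0, 1, 1], u_cols = [0, 1, 0, 1], u_vals = [2, -1, -1, 2]
```

The expensive part (sorting and hashing indices) runs once at construction; in JAX the per-step summation would typically use jax.ops.segment_sum(vals, map_idx, num_segments=n_unique).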

init, solve_dc

Same as CircuitLinearSolver.init and CircuitLinearSolver.solve_dc (identical source in circulax/solvers/linear.py).

SparseSolver

Bases: CircuitLinearSolver

Solves the system using JAX's iterative BiCGStab solver.

Best For
  • Large transient simulations on GPU (uses the previous step as a warm start).
  • Systems where N is too large for DenseSolver but vmap support is still needed.

Attributes:

  • diag_mask (Array): Boolean mask selecting the diagonal entries, used to build the preconditioner.

Methods:

  • assume_full_rank: Indicates whether the solver assumes the operator is full rank.
  • compute: Not used internally; raises NotImplementedError (call _solve_impl instead).
  • from_component_groups: Factory method that prepares the indices and the diagonal mask.
  • init: Initializes the solver state (no-op for stateless solvers).
  • solve_dc: Performs a robust DC operating-point analysis (Newton-Raphson).

assume_full_rank, compute

Same as CircuitLinearSolver.assume_full_rank and CircuitLinearSolver.compute (identical source in circulax/solvers/linear.py).

from_component_groups classmethod

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> SparseSolver

Factory method to prepare indices and diagonal mask.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "SparseSolver":
    """Factory method to prepare indices and diagonal mask."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        diag_mask=jnp.array(rows == cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
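The diag_mask built here marks which raw COO entries sit on the diagonal, which is all a Jacobi (diagonal) preconditioner needs. The NumPy sketch below pairs such a preconditioner with a textbook preconditioned BiCGStab loop that works directly on the COO arrays. It illustrates the technique only: the class itself is described as using JAX's BiCGStab (jax.scipy.sparse.linalg.bicgstab), the function name bicgstab_jacobi is hypothetical, and whether circulax's preconditioner is exactly Jacobi is an assumption.

```python
import numpy as np


def bicgstab_jacobi(rows, cols, vals, b, diag_mask, x0=None, tol=1e-10, maxiter=200):
    """Jacobi-preconditioned BiCGStab on a COO matrix (illustrative sketch)."""
    rows, cols = np.asarray(rows), np.asarray(cols)
    vals, diag_mask = np.asarray(vals, dtype=float), np.asarray(diag_mask, dtype=bool)
    b = np.asarray(b, dtype=float)
    n = b.size

    def matvec(x):
        out = np.zeros(n)
        np.add.at(out, rows, vals * x[cols])  # duplicate entries are summed
        return out

    # Jacobi preconditioner M = diag(A), built only from entries flagged by diag_mask.
    diag = np.bincount(rows[diag_mask], weights=vals[diag_mask], minlength=n)
    inv_diag = 1.0 / diag

    x = np.zeros(n) if x0 is None else np.asarray(x0, dtype=float).copy()  # warm start
    r = b - matvec(x)
    r_hat = r.copy()
    rho = alpha = omega = 1.0
    v = np.zeros(n)
    p = np.zeros(n)
    b_norm = np.linalg.norm(b) or 1.0
    for _ in range(maxiter):
        if np.linalg.norm(r) / b_norm < tol:
            break
        rho_new = r_hat @ r
        beta = (rho_new / rho) * (alpha / omega)
        p = r + beta * (p - omega * v)
        p_hat = inv_diag * p  # apply M^{-1}
        v = matvec(p_hat)
        alpha = rho_new / (r_hat @ v)
        s = r - alpha * v
        if np.linalg.norm(s) / b_norm < tol:
            x = x + alpha * p_hat
            break
        s_hat = inv_diag * s
        t = matvec(s_hat)
        omega = (t @ s) / (t @ t)
        x = x + alpha * p_hat + omega * s_hat
        r = s - omega * t
        rho = rho_new
    return x


# 3 x 3 tridiagonal test matrix in COO form; (0, 0) is stamped twice (3 + 1).
rows = np.array([0, 0, 1, 1, 1, 2, 2, 0])
cols = np.array([0, 1, 0, 1, 2, 1, 2, 0])
vals = np.array([3.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0, 1.0])
b = np.array([1.0, 2.0, 3.0])
x = bicgstab_jacobi(rows, cols, vals, b, diag_mask=rows == cols)
```

The x0 argument is where a transient solver can pass the previous timestep's solution, which is the warm start mentioned under Best For.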

init, solve_dc

Same as CircuitLinearSolver.init and CircuitLinearSolver.solve_dc (identical source in circulax/solvers/linear.py).

VectorizedTransientSolver

Bases: AbstractSolver

Transient solver that works strictly on flat (real) vectors.

It delegates the handling of complex (unrolled) systems to the linear_solver strategy.

analyze_circuit

analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False
) -> CircuitLinearSolver

Initializes a linear solver strategy for circuit analysis.

This function is a factory that selects and configures the numerical backend used to solve the linear system of equations derived from the circuit's topology.

Parameters:

  • groups (list, required): The component groups that define the circuit's structure and properties. They are passed unchanged to the backend's from_component_groups, whose signature annotates them as dict[str, Any].
  • num_vars (int, required): The total number of variables in the linear system.
  • backend (str, default 'default'): Name of the solver backend. Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'; 'default' selects the 'klu' solver.
  • is_complex (bool, default False): Whether the circuit analysis involves complex numbers (unrolled 2N x 2N system).

Returns:

  • CircuitLinearSolver: An instance of a circuit linear solver strategy configured for the specified backend and circuit parameters.

Raises:

  • ValueError: If the specified backend is not supported.

Source code in circulax/solvers/linear.py
def analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False
) -> CircuitLinearSolver:
    """Initializes a linear solver strategy for circuit analysis.

    This function serves as a factory and wrapper to select and configure the
    appropriate numerical backend for solving the linear system of equations
    derived from a circuit's topology.

    Args:
        groups (list): A list of component groups that define the circuit's
            structure and properties.
        num_vars (int): The total number of variables in the linear system.
        backend (str, optional): The name of the solver backend to use.
            Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'.
            Defaults to 'default', which uses the 'klu' solver.
        is_complex (bool, optional): A flag indicating whether the circuit
            analysis involves complex numbers. Defaults to False.

    Returns:
        CircuitLinearSolver: An instance of a circuit linear solver strategy
        configured for the specified backend and circuit parameters.

    Raises:
        ValueError: If the specified backend is not supported.

    """
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = (
            f"Unknown backend: '{backend}'. "
            f"Available backends are {list(backends.keys())}"
        )
        raise ValueError(msg)

    linear_strategy = solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)

    return linear_strategy
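The lookup above relies on a module-level backends mapping that is not shown on this page. The sketch below reproduces the same lookup-or-raise logic with stand-in classes; StubSolver, StubKLU, StubDense, the registry contents, and analyze_circuit_sketch are hypothetical, apart from the documented behaviour that 'default' resolves to the KLU solver.

```python
class StubSolver:
    """Hypothetical stand-in for a CircuitLinearSolver subclass."""

    def __init__(self, num_vars, is_complex):
        self.num_vars = num_vars
        self.is_complex = is_complex

    @classmethod
    def from_component_groups(cls, component_groups, num_vars, *, is_complex=False):
        return cls(num_vars, is_complex)


class StubKLU(StubSolver):
    pass


class StubDense(StubSolver):
    pass


# Hypothetical registry (subset only); 'default' aliases the KLU entry as documented.
backends = {"klu": StubKLU, "default": StubKLU, "dense": StubDense}


def analyze_circuit_sketch(groups, num_vars, backend="default", *, is_complex=False):
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = f"Unknown backend: '{backend}'. Available backends are {list(backends.keys())}"
        raise ValueError(msg)
    return solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)
```

In this sketch, adding a strategy only requires a class exposing from_component_groups and a new dictionary entry; the factory itself never changes.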

assemble_system_complex

assemble_system_complex(
    y_guess: Array, component_groups: dict, t1: float, dt: float
) -> tuple[Array, Array, Array]

Assemble the residual vectors and effective Jacobian values for an unrolled complex system.

The complex state vector is stored in unrolled (block) format: the first half of y_guess holds the real parts of all node voltages/states, the second half holds the imaginary parts. This avoids JAX's limited support for complex-valued sparse linear solvers by keeping all arithmetic real.

The Jacobian is split into four real blocks — RR, RI, IR, II — representing the partial derivatives of the real and imaginary residual components with respect to the real and imaginary state components respectively. The blocks are concatenated in RR→RI→IR→II order to match the sparsity index layout produced during compilation.
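For a complex linear system A x = b with A = A_r + i A_i, splitting into real and imaginary parts gives a real block system with RR = II = A_r, RI = -A_i and IR = A_i. This is what the four Jacobian blocks reduce to when the residual is a holomorphic function of the state; for non-holomorphic models the four blocks are generally independent. The NumPy sketch below checks the unrolled real solve against a direct complex solve; the helper name unroll_matrix and the matrix values are arbitrary.

```python
import numpy as np


def unroll_matrix(A):
    """Real 2n x 2n block form [[RR, RI], [IR, II]] of a complex n x n matrix."""
    return np.block([[A.real, -A.imag], [A.imag, A.real]])


A = np.array([[2.0 + 1.0j, -1.0], [-1.0, 2.0 - 0.5j]])
b = np.array([1.0 + 0.0j, 0.0 + 1.0j])
n = b.size

# Unrolled (block) layout: real parts first, then imaginary parts.
y = np.linalg.solve(unroll_matrix(A), np.concatenate([b.real, b.imag]))
x = y[:n] + 1j * y[n:]
```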

Parameters:

  • y_guess (Array, required): Unrolled state vector of shape (2 * num_vars,), where y_guess[:num_vars] are the real parts and y_guess[num_vars:] the imaginary parts.
  • component_groups (dict, required): Compiled component groups returned by compile_netlist, keyed by group name.
  • t1 (float, required): Time at which the system is evaluated.
  • dt (float, required): Timestep duration; the reactive Jacobian blocks are divided by dt.

Returns:

  tuple[Array, Array, Array]: A three-tuple (total_f, total_q, jac_vals) where:
  • total_f is the assembled resistive residual in unrolled format, shape (2 * num_vars,).
  • total_q is the assembled reactive residual in unrolled format, shape (2 * num_vars,).
  • jac_vals holds the concatenated non-zero values of the four effective Jacobian blocks (RR, RI, IR, II) in group-sorted order.

Source code in circulax/solvers/assembly.py
def assemble_system_complex(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
) -> tuple[Array, Array, Array]:
    """Assemble the residual vectors and effective Jacobian values for an unrolled complex system.

    The complex state vector is stored in unrolled (block) format: the first
    half of ``y_guess`` holds the real parts of all node voltages/states, the
    second half holds the imaginary parts. This avoids JAX's limited support
    for complex-valued sparse linear solvers by keeping all arithmetic real.

    The Jacobian is split into four real blocks — RR, RI, IR, II — representing
    the partial derivatives of the real and imaginary residual components with
    respect to the real and imaginary state components respectively. The blocks
    are concatenated in RR→RI→IR→II order to match the sparsity index layout
    produced during compilation.

    Args:
        y_guess: Unrolled state vector of shape ``(2 * num_vars,)``, where
            ``y_guess[:num_vars]`` are real parts and ``y_guess[num_vars:]``
            are imaginary parts.
        component_groups: Compiled component groups returned by
            :func:`compile_netlist`, keyed by group name.
        t1: Time at which the system is being evaluated.
        dt: Timestep duration, used to scale the reactive Jacobian blocks.

    Returns:
        A three-tuple ``(total_f, total_q, jac_vals)`` where:

        - **total_f** — assembled resistive residual in unrolled format,
            shape ``(2 * num_vars,)``.
        - **total_q** — assembled reactive residual in unrolled format,
            shape ``(2 * num_vars,)``.
        - **jac_vals** — concatenated non-zero values of the four effective
            Jacobian blocks (RR, RI, IR, II) in group-sorted order.

    """
    sys_size = y_guess.shape[0]
    half_size = sys_size // 2
    y_real, y_imag = y_guess[:half_size], y_guess[half_size:]

    total_f = jnp.zeros(sys_size, dtype=jnp.float64)
    total_q = jnp.zeros(sys_size, dtype=jnp.float64)

    vals_blocks: list[list[Array]] = [[], [], [], []]

    for k in sorted(component_groups.keys()):
        group = component_groups[k]
        v_r, v_i = y_real[group.var_indices], y_imag[group.var_indices]

        physics_split = functools.partial(_complex_physics, group=group, t1=t1)

        (fr, fi, qr, qi), (dfr_r, dfi_r, dqr_r, dqi_r), (dfr_i, dfi_i, dqr_i, dqi_i) = (
            jax.vmap(functools.partial(_primal_and_jac_complex, physics_split))(
                v_r, v_i, group.params
            )
        )

        idx_r, idx_i = group.eq_indices, group.eq_indices + half_size
        total_f = total_f.at[idx_r].add(fr).at[idx_i].add(fi)
        total_q = total_q.at[idx_r].add(qr).at[idx_i].add(qi)

        vals_blocks[0].append((dfr_r + dqr_r / dt).reshape(-1))  # RR
        vals_blocks[1].append((dfr_i + dqr_i / dt).reshape(-1))  # RI
        vals_blocks[2].append((dfi_r + dqi_r / dt).reshape(-1))  # IR
        vals_blocks[3].append((dfi_i + dqi_i / dt).reshape(-1))  # II

    all_vals = jnp.concatenate([jnp.concatenate(b) for b in vals_blocks])
    return total_f, total_q, all_vals
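The .at[idx].add(...) updates above are what make shared nodes work: when several components write to the same equation index, their contributions must be summed. The NumPy sketch below contrasts an accumulating scatter (np.add.at, the NumPy counterpart of JAX's .at[].add) with plain fancy-index `+=`, which applies only one of the writes to a repeated index. The two-component eq_indices layout is a hypothetical example.

```python
import numpy as np

# Hypothetical group: two 2-terminal components sharing node 1.
eq_indices = np.array([[0, 1], [1, 2]])
f_local = np.array([[1.0, -1.0], [2.0, -2.0]])  # per-component residual contributions

accumulated = np.zeros(3)
np.add.at(accumulated, eq_indices, f_local)  # node 1 receives -1.0 + 2.0

buffered = np.zeros(3)
buffered[eq_indices] += f_local  # repeated index 1: only one write survives
```

Here accumulated ends up as [1.0, 1.0, -2.0] (Kirchhoff's current law at node 1), while buffered[1] holds only one component's contribution.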

assemble_system_real

assemble_system_real(
    y_guess: Array, component_groups: dict, t1: float, dt: float
) -> tuple[Array, Array, Array]

Assemble the residual vectors and effective Jacobian values for a real system.

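For each component group, evaluates the physics at t1 and computes the forward-mode Jacobian via jax.jacfwd. The effective Jacobian combines the resistive and reactive contributions as J_eff = df/dy + (1/dt) * dq/dy, which is intended to match the implicit trapezoidal discretisation used by the solver. Note that a 1/dt coefficient is what a backward-Euler residual, f(y) + (q(y) - q_prev)/dt, produces; a textbook trapezoidal residual, f(y) + f_prev + 2 * (q(y) - q_prev)/dt, yields df/dy + (2/dt) * dq/dy, so the two agree only if the caller passes half the physical step as dt.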

Components are processed in sorted key order to ensure a deterministic non-zero layout in the sparse Jacobian, which is required for the factorisation step.
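The origin of J_eff can be checked on a hypothetical single-node circuit: a diode (the resistive part f) in parallel with a linear capacitor (the reactive part q). The sketch assumes a backward-Euler residual r(y) = f(y) + (q(y) - q_prev) / dt. This is an illustrative choice, not a statement about circulax's internals. Differentiating r with JAX gives exactly df/dy + (1/dt) * dq/dy. For comparison, a trapezoidal residual multiplied through by 2 has the same Jacobian structure with dt replaced by dt / 2. All component values are made up.

```python
import jax
import jax.numpy as jnp

# Hypothetical component values, chosen only for illustration.
I_S, V_T, C = 1e-12, 0.02585, 1e-9


def f(y):
    """Resistive part: Shockley diode current."""
    return I_S * (jnp.exp(y / V_T) - 1.0)


def q(y):
    """Reactive part: charge on a linear capacitor."""
    return C * y


def be_residual(y, q_prev, dt):
    """Assumed backward-Euler residual: f(y) + (q(y) - q_prev) / dt."""
    return f(y) + (q(y) - q_prev) / dt


def tr_residual(y, y_prev, q_prev, dt):
    """Assumed trapezoidal residual, multiplied through by 2."""
    return f(y) + f(y_prev) + (q(y) - q_prev) / (dt / 2)


y, y_prev, dt = 0.6, 0.55, 1e-6
q_prev = q(y_prev)

j_be = jax.grad(be_residual)(y, q_prev, dt)          # d(residual)/dy via AD
j_eff = jax.grad(f)(y) + jax.grad(q)(y) / dt          # df/dy + (1/dt) * dq/dy
j_tr = jax.grad(tr_residual)(y, y_prev, q_prev, dt)   # same form, dt -> dt / 2
print(j_be, j_eff, j_tr)
```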

Parameters:

  • y_guess (Array, required): Current state vector of shape (sys_size,).
  • component_groups (dict, required): Compiled component groups returned by compile_netlist, keyed by group name.
  • t1 (float, required): Time at which the system is evaluated.
  • dt (float, required): Timestep duration, used to scale the reactive Jacobian block.

Returns:

tuple[Array, Array, Array]: A three-tuple (total_f, total_q, jac_vals) where:

  • total_f: assembled resistive residual, shape (sys_size,).
  • total_q: assembled reactive residual, shape (sys_size,).
  • jac_vals: concatenated non-zero values of the effective Jacobian in group-sorted order, ready to be passed to the sparse linear solver.
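jac_vals holds only the matrix values. The matching row and column indices are static; CircuitLinearSolver stores the matrix structure (indices, rows, cols) separately. The sketch below builds hypothetical index arrays and uses a dense scatter-add in place of a real sparse factorisation. It shows two things. First, duplicate (row, col) pairs from different components are summed, which is how shared nodes accumulate conductance. Second, values emitted in a different group order than the static pattern land in the wrong slots without raising any error. `group_pattern` and `scatter_dense` are illustrative helpers, not circulax functions.

```python
import jax.numpy as jnp


def group_pattern(eq_indices):
    """(rows, cols) for each instance's dense k x k local Jacobian, flattened row-major."""
    n_inst, k = eq_indices.shape
    rows = jnp.broadcast_to(eq_indices[:, :, None], (n_inst, k, k)).reshape(-1)
    cols = jnp.broadcast_to(eq_indices[:, None, :], (n_inst, k, k)).reshape(-1)
    return rows, cols


def scatter_dense(rows, cols, vals, size):
    """Dense stand-in for a sparse matrix: duplicate (row, col) entries are summed."""
    return jnp.zeros((size, size), dtype=vals.dtype).at[rows, cols].add(vals)


# Hypothetical circuit: group "r1" holds two 1 S conductances (nodes 0-1 and 1-2),
# group "r2" holds one 0.5 S conductance (nodes 2-0).
eq_indices = {"r1": jnp.array([[0, 1], [1, 2]]), "r2": jnp.array([[2, 0]])}
conductance = {"r1": 1.0, "r2": 0.5}
stamp = jnp.array([[1.0, -1.0], [-1.0, 1.0]])  # local Jacobian of a unit conductance


def group_vals(k):
    """Flattened local Jacobians for every instance in group k."""
    n_inst = eq_indices[k].shape[0]
    return jnp.tile(conductance[k] * stamp, (n_inst, 1, 1)).reshape(-1)


order = sorted(eq_indices)  # deterministic group order, as in assemble_system_real
patterns = [group_pattern(eq_indices[k]) for k in order]
rows = jnp.concatenate([r for r, _ in patterns])
cols = jnp.concatenate([c for _, c in patterns])

jac = scatter_dense(rows, cols, jnp.concatenate([group_vals(k) for k in order]), 3)
# Values produced in a different order than the static pattern are silently misplaced.
jac_bad = scatter_dense(rows, cols, jnp.concatenate([group_vals(k) for k in reversed(order)]), 3)
print(jac)
```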
Source code in circulax/solvers/assembly.py
def assemble_system_real(
    y_guess: Array,
    component_groups: dict,
    t1: float,
    dt: float,
) -> tuple[Array, Array, Array]:
    """Assemble the residual vectors and effective Jacobian values for a real system.

    For each component group, evaluates the physics at ``t1`` and computes the
    forward-mode Jacobian via ``jax.jacfwd``. The effective Jacobian combines
    the resistive and reactive contributions as ``J_eff = df/dy + (1/dt) * dq/dy``,
    consistent with the implicit trapezoidal discretisation used by the solver.

    Components are processed in sorted key order to ensure a deterministic
    non-zero layout in the sparse Jacobian, which is required for the
    factorisation step.

    Args:
        y_guess: Current state vector of shape ``(sys_size,)``.
        component_groups: Compiled component groups returned by
            :func:`compile_netlist`, keyed by group name.
        t1: Time at which the system is being evaluated.
        dt: Timestep duration, used to scale the reactive Jacobian block.

    Returns:
        A three-tuple ``(total_f, total_q, jac_vals)`` where:

        - **total_f** — assembled resistive residual, shape ``(sys_size,)``.
        - **total_q** — assembled reactive residual, shape ``(sys_size,)``.
        - **jac_vals** — concatenated non-zero values of the effective Jacobian
            in group-sorted order, ready to be passed to the sparse linear solver.

    """
    sys_size = y_guess.shape[0]
    total_f = jnp.zeros(sys_size, dtype=y_guess.dtype)
    total_q = jnp.zeros(sys_size, dtype=y_guess.dtype)
    vals_list = []

    for k in sorted(component_groups.keys()):
        group = component_groups[k]
        v_locs = y_guess[group.var_indices]

        physics_at_t1 = functools.partial(_real_physics, group=group, t1=t1)

        (f_l, q_l), (df_l, dq_l) = jax.vmap(
            functools.partial(_primal_and_jac_real, physics_at_t1)
        )(v_locs, group.params)

        total_f = total_f.at[group.eq_indices].add(f_l)
        total_q = total_q.at[group.eq_indices].add(q_l)
        j_eff = df_l + (dq_l / dt)
        vals_list.append(j_eff.reshape(-1))

    return total_f, total_q, jnp.concatenate(vals_list)
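The loop body above uses a common JAX pattern. The physics is written for a single component instance, paired with its forward-mode Jacobian, and jax.vmap maps both over every instance in the group. The residuals are then scatter-added into the global vector. The sketch below applies this to a hypothetical group of linear resistors. `resistor_physics` and `primal_and_jac` are illustrative stand-ins, not circulax's private `_real_physics` and `_primal_and_jac_real`. For simplicity, the equation indices are assumed to equal the variable indices.

```python
import jax
import jax.numpy as jnp


def resistor_physics(v_loc, r):
    """One resistor between two nodes: returns local (f, q) residuals."""
    i = (v_loc[0] - v_loc[1]) / r  # current flowing from terminal a to terminal b
    f = jnp.array([i, -i])         # KCL contributions at the two terminals
    q = jnp.zeros(2)               # purely resistive: no charge
    return f, q


def primal_and_jac(v_loc, r):
    """Evaluate (f, q) and their Jacobians w.r.t. the local voltages only."""
    return resistor_physics(v_loc, r), jax.jacfwd(resistor_physics)(v_loc, r)


var_indices = jnp.array([[0, 1], [1, 2]])  # hypothetical node pairs, one row per instance
params = jnp.array([100.0, 200.0])         # hypothetical resistances in ohms
y = jnp.array([1.0, 0.5, 0.0])             # node voltages
dt = 1e-6

v_locs = y[var_indices]  # shape (n_inst, 2)
(f_l, q_l), (df_l, dq_l) = jax.vmap(primal_and_jac)(v_locs, params)

# Scatter-add: contributions to shared nodes accumulate.
total_f = jnp.zeros_like(y).at[var_indices].add(f_l)
total_q = jnp.zeros_like(y).at[var_indices].add(q_l)
j_vals = (df_l + dq_l / dt).reshape(-1)  # per-instance 2 x 2 blocks, row-major
print(total_f, j_vals)
```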

setup_transient

setup_transient(
    groups: list,
    linear_strategy: CircuitLinearSolver,
    transient_solver: AbstractSolver = None,
) -> Callable[..., Solution]

Configures and returns a function for executing transient analysis.

This function acts as a factory, preparing a transient solver that is pre-configured with the circuit's linear strategy. It returns a callable that executes the time-domain simulation using diffrax.diffeqsolve.

Parameters:

  • groups (list, required): A list of component groups that define the circuit.
  • linear_strategy (CircuitLinearSolver, required): The configured linear solver strategy, typically obtained from analyze_circuit.
  • transient_solver (AbstractSolver subclass, optional, default None): The transient solver class to use. It is instantiated as transient_solver(linear_solver=linear_strategy). If None, VectorizedTransientSolver is used.

Returns:

Callable[..., Solution]: A function that executes the transient analysis and returns a diffrax.Solution. The returned function accepts keyword-only arguments:

  • t0 (float): The start time of the simulation.
  • t1 (float): The end time of the simulation.
  • dt0 (float): The initial time step for the solver.
  • y0 (ArrayLike): The initial state vector of the system.
  • saveat (diffrax.SaveAt, optional): The time points at which to save the solution. Defaults to None.
  • max_steps (int, optional): The maximum number of steps the solver may take. Defaults to 100000.
  • throw (bool, optional): If True, the solver raises an error on failure. Defaults to False.
  • term (diffrax.AbstractTerm, optional): The term defining the ODE. Defaults to an ODETerm whose vector field returns zeros.
  • solver (optional): Overrides the pre-configured transient solver instance.
  • args (optional): Arguments passed to the solver. Defaults to (groups, sys_size), where sys_size equals linear_strategy.sys_size for real systems and half of it for complex-unrolled systems.
  • stepsize_controller (diffrax.AbstractStepSizeController, optional): The step size controller. Defaults to ConstantStepSize().
  • **kwargs: Any remaining keyword arguments, passed directly to diffrax.diffeqsolve.
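The returned function pops term, solver, args and stepsize_controller out of **kwargs before it calls diffrax.diffeqsolve. Each of these can therefore be overridden per call, while every other keyword is forwarded unchanged. The dependency-free sketch below reproduces only that dispatch logic. `make_runner` and `fake_solve` are hypothetical stand-ins for setup_transient and diffrax.diffeqsolve, and the string defaults take the place of the real diffrax objects.

```python
from typing import Any, Callable


def fake_solve(**call: Any) -> dict[str, Any]:
    """Stand-in for diffrax.diffeqsolve: echoes the keywords it receives."""
    return call


def make_runner(default_solver: str, default_args: tuple) -> Callable[..., dict[str, Any]]:
    """Hypothetical factory mirroring the keyword handling of setup_transient."""

    def run(*, t0: float, t1: float, dt0: float, y0: Any, saveat: Any = None,
            max_steps: int = 100000, throw: bool = False, **kwargs: Any) -> dict[str, Any]:
        # Overridable defaults are popped first, so they are never passed twice.
        term = kwargs.pop("term", "zero-valued ODETerm")
        solver = kwargs.pop("solver", default_solver)
        args = kwargs.pop("args", default_args)
        controller = kwargs.pop("stepsize_controller", "ConstantStepSize()")
        # Whatever remains in kwargs is forwarded untouched.
        return fake_solve(terms=term, solver=solver, t0=t0, t1=t1, dt0=dt0, y0=y0,
                          args=args, saveat=saveat, max_steps=max_steps, throw=throw,
                          stepsize_controller=controller, **kwargs)

    return run


run = make_runner("VectorizedTransientSolver", ("groups", 3))
default_call = run(t0=0.0, t1=1e-6, dt0=1e-9, y0=[0.0, 0.0, 0.0])
custom_call = run(t0=0.0, t1=1e-6, dt0=1e-9, y0=[0.0, 0.0, 0.0],
                  solver="MyOtherSolver", adjoint="RecursiveCheckpointAdjoint()")
print(default_call["solver"], custom_call["solver"])
```

Because the parameters are keyword-only, a positional call such as run(0.0, 1e-6, 1e-9, y0) raises a TypeError. This guards against silently swapping t1 and dt0.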

Source code in circulax/solvers/transient.py
def setup_transient(
    groups: list,
    linear_strategy: CircuitLinearSolver,
    transient_solver: AbstractSolver = None
) -> Callable[..., diffrax.Solution]:
    """Configures and returns a function for executing transient analysis.

    This function acts as a factory, preparing a transient solver that is
    pre-configured with the circuit's linear strategy. It returns a callable
    that executes the time-domain simulation using `diffrax.diffeqsolve`.

    Args:
        groups (list): A list of component groups that define the circuit.
        linear_strategy (CircuitLinearSolver): The configured linear solver
            strategy, typically obtained from `analyze_circuit`.
        transient_solver (optional): The transient solver class to use.
            If None, `VectorizedTransientSolver` will be used.

    Returns:
        Callable[..., diffrax.Solution]: A function that executes the transient analysis.
        This returned function accepts the following arguments:

            t0 (float): The start time of the simulation.
            t1 (float): The end time of the simulation.
            dt0 (float): The initial time step for the solver.
            y0 (ArrayLike): The initial state vector of the system.
            saveat (diffrax.SaveAt, optional): Specifies time points at which
                to save the solution. Defaults to None.
            max_steps (int, optional): The maximum number of steps the solver
                can take. Defaults to 100000.
            throw (bool, optional): If True, the solver will raise an error on
                failure. Defaults to False.
            term (diffrax.AbstractTerm, optional): The term defining the ODE.
                Defaults to a zero-value ODETerm.
            stepsize_controller (diffrax.AbstractStepSizeController, optional):
                The step size controller. Defaults to `ConstantStepSize()`.
            **kwargs: Additional keyword arguments to pass directly to
                `diffrax.diffeqsolve`.

    """
    if transient_solver is None:
        transient_solver = VectorizedTransientSolver

    tsolver = transient_solver(linear_solver=linear_strategy)

    sys_size = (
        linear_strategy.sys_size // 2
        if linear_strategy.is_complex
        else linear_strategy.sys_size
    )

    def _execute_transient(
        *,
        t0: float,
        t1: float,
        dt0: float,
        y0: ArrayLike,
        saveat: diffrax.SaveAt = None,
        max_steps: int = 100000,
        throw: bool = False,
        **kwargs: Any,
    ) -> diffrax.Solution:
        """Executes the transient simulation for the pre-configured circuit."""
        term = kwargs.pop("term", diffrax.ODETerm(lambda t, y, args: jnp.zeros_like(y)))
        solver = kwargs.pop("solver", tsolver)
        args = kwargs.pop("args", (groups, sys_size))
        stepsize_controller = kwargs.pop("stepsize_controller", ConstantStepSize())

        sol = diffrax.diffeqsolve(
            terms=term,
            solver=solver,
            t0=t0,
            t1=t1,
            dt0=dt0,
            y0=y0,
            args=args,
            saveat=saveat,
            max_steps=max_steps,
            throw=throw,
            stepsize_controller=stepsize_controller,
            **kwargs,
        )

        return sol

    return _execute_transient