
linear ¤

Circuit Linear Solvers Strategy Pattern.

This module defines the linear algebra strategies used by the circuit simulator. It leverages the lineax abstract base class to provide interchangeable solvers that work seamlessly with JAX transformations (JIT, VMAP, GRAD).

Architecture¤

The core idea is to separate the physics assembly (calculating Jacobian values) from the linear solve (solving the resulting system J * delta = -Residual).

Classes:

Name Description
CircuitLinearSolver

Abstract base defining the interface and common DC logic.

DenseSolver

Uses JAX's native dense solver (LU decomposition). Best for small circuits (N < 2000) and for GPU execution.

KLUSolver

Uses the KLU sparse solver (via klujax). Best for large circuits on CPU.

SparseSolver

Uses JAX's iterative BiCGStab. Best for large transient simulations on GPU.

Functions:

Name Description
analyze_circuit

Initializes a linear solver strategy for circuit analysis.

Attributes:

Name Type Description
DAMPING_EPS float

Small additive epsilon that prevents division by zero in the damping formula.

DAMPING_FACTOR float

Newton-step damping coefficient: each step is scaled by min(1, DAMPING_FACTOR / |δy|_max), so no entry of the update exceeds DAMPING_FACTOR in magnitude.

DC_DT float

Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0).

GROUND_STIFFNESS float

Penalty added to ground-node diagonal entries to enforce V=0.

DAMPING_EPS module-attribute ¤

DAMPING_EPS: float = 1e-09

Small additive epsilon that prevents division by zero in the damping formula.

DAMPING_FACTOR module-attribute ¤

DAMPING_FACTOR: float = 0.5

Newton-step damping coefficient: each step is scaled by min(1, DAMPING_FACTOR / |δy|_max), so no entry of the update exceeds DAMPING_FACTOR in magnitude.
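As a rough illustration of the damping rule used in solve_dc, the sketch below reproduces the scale factor min(1, DAMPING_FACTOR / (|δy|_max + DAMPING_EPS)) in plain Python, using lists instead of JAX arrays so that it runs without dependencies. The helper name damped_step is hypothetical and not part of the module.

```python
DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9


def damped_step(delta):
    """Scale a Newton update so its largest entry is at most DAMPING_FACTOR."""
    max_change = max(abs(d) for d in delta)
    damping = min(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
    return [d * damping for d in delta]


# A small update passes through unchanged (damping == 1.0).
small = damped_step([0.1, -0.2])

# A large update is rescaled so its largest entry is about 0.5.
large = damped_step([4.0, -1.0])
```

The direction of the update is preserved; only its length changes.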

DC_DT module-attribute ¤

DC_DT: float = 1e+18

Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0).
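To see why such a large timestep removes capacitors from the DC problem, the arithmetic below compares the C/dt term for a transient step and for DC_DT. The capacitance and the transient timestep are illustrative values, and the assumption that the stamp is exactly C/dt (as in a backward-Euler companion model) is mine; the actual stamp depends on the integrator used by the assembly code.

```python
DC_DT = 1e18

C = 1e-12              # 1 pF capacitor (illustrative)
dt_transient = 1e-9    # 1 ns transient step (illustrative)

# Companion conductance of the capacitor, assuming a C/dt stamp.
g_transient = C / dt_transient   # about 1e-3 S: comparable to a 1 kOhm resistor
g_dc = C / DC_DT                 # about 1e-30 S: negligible next to any real conductance
```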

GROUND_STIFFNESS module-attribute ¤

GROUND_STIFFNESS: float = 1000000000.0

Penalty added to ground-node diagonal entries to enforce V=0.
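The penalty approach can be sketched on a two-node network: a single resistor between node 0 (ground) and node 1, with a current injected into node 1. Without the penalty the conductance matrix is singular; adding GROUND_STIFFNESS to the ground diagonal pins node 0 close to 0 V. The circuit values are illustrative and the 2x2 solve uses Cramer's rule in plain Python rather than any solver from this module.

```python
GROUND_STIFFNESS = 1e9

g = 1e-3       # conductance of a 1 kOhm resistor between node 0 and node 1
i_src = 1e-3   # 1 mA injected into node 1

# Resistor stamps, with the penalty added on the ground-node diagonal.
a00, a01 = g + GROUND_STIFFNESS, -g
a10, a11 = -g, g
b0, b1 = 0.0, i_src

# Solve the 2x2 system with Cramer's rule.
det = a00 * a11 - a01 * a10
v0 = (b0 * a11 - a01 * b1) / det   # ground node: about i_src / GROUND_STIFFNESS
v1 = (a00 * b1 - a10 * b0) / det   # node 1: about i_src / g = 1 V
```

The ground voltage is not exactly zero but of order i_src / GROUND_STIFFNESS, which is the usual trade-off of a penalty constraint.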

CircuitLinearSolver ¤

Bases: AbstractLinearSolver

Abstract Base Class for all circuit linear solvers.

This class provides the unified interface for:

1. Storing static matrix structure (indices, rows, cols).
2. Handling Real vs. Complex-Unrolled system configurations.
3. Providing a robust Newton-Raphson DC Operating Point solver.

Attributes:

Name Type Description
ground_indices Array

Indices of nodes connected to ground (forced to 0V).

is_complex bool

Static flag. If True, the system is 2N x 2N (Real/Imag unrolled). If False, the system is N x N (Real).
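The real/imaginary unrolling can be illustrated on a single complex equation z * u = rhs, which is equivalent to a 2 x 2 real system [[a, -b], [b, a]] @ [x, y] = [c, d] with z = a + ib, u = x + iy and rhs = c + id. This block layout is the textbook one and is an assumption here; the ordering of real and imaginary parts actually used by the assembly code is not shown on this page. The function name solve_unrolled is hypothetical.

```python
def solve_unrolled(z, rhs):
    """Solve z * u = rhs through the equivalent 2x2 real system."""
    a, b = z.real, z.imag
    c, d = rhs.real, rhs.imag
    det = a * a + b * b          # determinant of [[a, -b], [b, a]]
    x = (a * c + b * d) / det    # real part of u
    y = (a * d - b * c) / det    # imaginary part of u
    return x, y


z = 2.0 + 1.0j
rhs = 3.0 - 4.0j
x, y = solve_unrolled(z, rhs)
reference = rhs / z              # direct complex division for comparison
```

For an N-unknown complex system the same construction yields the 2N x 2N real system mentioned above.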

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

Circuit code calls _solve_impl directly to avoid overhead. This method exists only to satisfy the AbstractLinearSolver interface and raises NotImplementedError if called.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt = DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS * V penalty to the residual at ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
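The damped Newton iteration above can be sketched on a one-unknown circuit: a voltage source feeding a diode through a resistor. This is a plain-Python stand-in with a hand-written loop, not the Optimistix-based implementation shown in the source; the circuit values and the names residual, jacobian and solve_dc_scalar are illustrative assumptions.

```python
import math

DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9

# Illustrative circuit: 5 V source -> 1 kOhm resistor -> diode -> ground.
V_SRC, R, I_SAT, V_T = 5.0, 1e3, 1e-12, 0.02585


def residual(v):
    """KCL at the diode node: resistor current plus diode current."""
    return (v - V_SRC) / R + I_SAT * (math.exp(v / V_T) - 1.0)


def jacobian(v):
    """Derivative of the residual with respect to the node voltage."""
    return 1.0 / R + (I_SAT / V_T) * math.exp(v / V_T)


def solve_dc_scalar(v_guess, max_steps=100, tol=1e-12):
    """Damped Newton iteration on a single node voltage."""
    v = v_guess
    for _ in range(max_steps):
        delta = -residual(v) / jacobian(v)   # solve J * delta = -R
        damping = min(1.0, DAMPING_FACTOR / (abs(delta) + DAMPING_EPS))
        step = damping * delta
        v += step
        if abs(step) < tol:
            break
    return v


v_op = solve_dc_scalar(0.0)
```

Starting from 0 V, the undamped first step would be about 5 V, which would push the diode exponential far outside its useful range; the damping caps that step at 0.5 V.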

DenseSolver ¤

Bases: CircuitLinearSolver

Solves the system using dense matrix factorization (LU).

Best For
  • Small to Medium circuits (N < 2000).
  • Wavelength sweeps (AC Analysis) on GPU.
  • Systems where VMAP parallelism is critical.

Attributes:

Name Type Description
static_rows Array

Row indices for placing values into dense matrix.

static_cols Array

Column indices.

g_leak float

Leakage conductance added to diagonal to prevent singularity.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

from_component_groups

Factory method to pre-calculate indices for the dense matrix.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

Circuit code calls _solve_impl directly to avoid overhead. This method exists only to satisfy the AbstractLinearSolver interface and raises NotImplementedError if called.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> DenseSolver

Factory method to pre-calculate indices for the dense matrix.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "DenseSolver":
    """Factory method to pre-calculate indices for the dense matrix."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
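The role of the precomputed row and column indices can be sketched as a scatter-add of component stamps into a dense matrix. The example uses nested Python lists so that it runs without dependencies; with JAX arrays the same operation would typically be written as jnp.zeros((n, n)).at[rows, cols].add(vals). The function name assemble_dense and the two-resistor network are illustrative, and how DenseSolver actually builds its matrix is not shown on this page.

```python
def assemble_dense(rows, cols, vals, n):
    """Scatter-add COO-style stamps into an n x n dense matrix."""
    matrix = [[0.0] * n for _ in range(n)]
    for r, c, v in zip(rows, cols, vals):
        matrix[r][c] += v    # duplicate (row, col) entries accumulate
    return matrix


# Two resistors: g1 between nodes 0 and 1, g2 between nodes 1 and 2.
g1, g2 = 1.0, 0.5
rows = [0, 0, 1, 1, 1, 1, 2, 2]
cols = [0, 1, 0, 1, 1, 2, 1, 2]
vals = [g1, -g1, -g1, g1, g2, -g2, -g2, g2]

A = assemble_dense(rows, cols, vals, n=3)
```

Because the index arrays depend only on circuit topology, they can be computed once and reused while the values change on every Newton step.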

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt = DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS * V penalty to the residual at ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value

KLUSolver ¤

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.
  • Cases where DenseSolver runs out of memory (OOM).
Note

Does NOT support vmap (batching) automatically.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

Circuit code calls _solve_impl directly to avoid overhead. This method exists only to satisfy the AbstractLinearSolver interface and raises NotImplementedError if called.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> KLUSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
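The idea behind pre-hashing indices for coalescence can be sketched as follows: duplicate (row, col) pairs are mapped once to a unique slot, and at solve time the raw values are summed per slot. This plain-Python version is only a sketch of the concept; the real _klu_deduplicate also takes ground indices and the system size, which are omitted here, and the function names deduplicate and segment_sum below are illustrative (the latter mimics what jax.ops.segment_sum does).

```python
def deduplicate(rows, cols):
    """Map each (row, col) pair to the index of its unique slot."""
    slots = {}
    map_idx = []
    for pair in zip(rows, cols):
        # setdefault assigns the next free slot the first time a pair is seen.
        map_idx.append(slots.setdefault(pair, len(slots)))
    u_rows = [r for r, _ in slots]
    u_cols = [c for _, c in slots]
    return u_rows, u_cols, map_idx, len(slots)


def segment_sum(vals, map_idx, num_segments):
    """Sum values that share a slot (pure-Python analogue of a segment sum)."""
    out = [0.0] * num_segments
    for v, i in zip(vals, map_idx):
        out[i] += v
    return out


rows = [0, 1, 1, 0]
cols = [0, 1, 1, 0]
u_rows, u_cols, map_idx, n_unique = deduplicate(rows, cols)

coalesced = segment_sum([1.0, 2.0, 3.0, 4.0], map_idx, n_unique)
```

The deduplication runs once per circuit, while the cheap summation runs on every solve.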

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt = DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS * V penalty to the residual at ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value

KLUSplitFactorSolver ¤

Bases: KLUSplitSolver

Solves the system using the KLU sparse solver (via klujax) with split interface.

This solver performs symbolic analysis ONCE during initialization and reuses the symbolic handle for subsequent solves, significantly speeding up non-linear simulations (Newton-Raphson iterations). This version of the solver is further enhanced by calculating the numeric part of the KLU solution only once.

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.

Attributes:

Name Type Description
Bp, Bi

CSC format indices (fixed structure).

csc_map_idx Bi

Mapping from raw value indices to CSC value vector.

symbolic_handle Bi

Pointer to the pre-computed KLU symbolic analysis.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

factor_jacobian

Factor the Jacobian and return numeric handle.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

solve_with_frozen_jacobian

Solve using pre-computed numeric factorization (for frozen Jacobian Newton).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

Circuit code calls _solve_impl directly to avoid overhead. This method exists only to satisfy the AbstractLinearSolver interface and raises NotImplementedError if called.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

factor_jacobian ¤

factor_jacobian(all_vals: Array) -> Array

Factor the Jacobian and return numeric handle.

Source code in circulax/solvers/linear.py
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return numeric handle."""
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)

    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    coalesced_vals = jax.ops.segment_sum(
        raw_vals, self.map_idx, num_segments=self.n_unique
    )

    return klujax.factor(
        self.u_rows, self.u_cols, coalesced_vals, self.symbolic_handle
    )

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> KLUSplitSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbolic = klujax.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbolic,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt = DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS * V penalty to the residual at ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value

solve_with_frozen_jacobian ¤

solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution

Solve using pre-computed numeric factorization (for frozen Jacobian Newton).

Source code in circulax/solvers/linear.py
def solve_with_frozen_jacobian(
    self, residual: jax.Array, numeric: jax.Array
) -> lx.Solution:
    """Solve using pre-computed numeric factorization (for frozen Jacobian Newton)."""
    solution = klujax.solve_with_numeric(
        numeric, residual, self._handle_wrapper.handle
    )
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )
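A frozen-Jacobian Newton iteration (sometimes called the chord method) evaluates and factors the Jacobian once and then reuses it for every subsequent residual solve. The scalar sketch below illustrates only that idea: the "factorization" is a single derivative value, and the test function and the names f, df and chord_newton are illustrative rather than part of this module.

```python
import math


def f(x):
    """Illustrative nonlinear residual with root sqrt(2)."""
    return x * x - 2.0


def df(x):
    """Derivative of the residual."""
    return 2.0 * x


def chord_newton(x0, steps=30):
    """Newton iteration that reuses the Jacobian evaluated at x0."""
    j_frozen = df(x0)            # evaluated ("factored") once
    x = x0
    for _ in range(steps):
        x -= f(x) / j_frozen     # each step only needs a cheap solve
    return x


root = chord_newton(1.5)
```

Convergence becomes linear instead of quadratic, so the approach pays off when factorization dominates the cost of each iteration and the Jacobian changes slowly.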

KLUSplitSolver ¤

Bases: CircuitLinearSolver

Solves the system using the KLU sparse solver (via klujax) with split interface.

This solver performs symbolic analysis ONCE during initialization and reuses the symbolic handle for subsequent solves, significantly speeding up non-linear simulations (Newton-Raphson iterations).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.

Attributes:

Name Type Description
Bp, Bi

CSC format indices (fixed structure).

csc_map_idx Bi

Mapping from raw value indices to CSC value vector.

symbolic_handle Bi

Pointer to the pre-computed KLU symbolic analysis.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

Circuit code calls _solve_impl directly to avoid overhead. This method exists only to satisfy the AbstractLinearSolver interface and raises NotImplementedError if called.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> KLUSplitSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbolic = klujax.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbolic,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt = DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS * V penalty to the residual at ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (setting specific rows/cols to identity).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
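The damping rule in step 4 can be examined in isolation. The following is a minimal sketch in plain Python (lists instead of JAX arrays); `damped_update` is a hypothetical name, and the constants copy the module defaults. It shows that small Newton steps pass through unchanged while large ones are scaled so that the biggest component moves by roughly DAMPING_FACTOR.

```python
DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9


def damped_update(y, delta):
    """Apply the same scaling as dc_step: one scalar factor for the whole step."""
    max_change = max(abs(d) for d in delta)
    damping = min(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
    return [yi + di * damping for yi, di in zip(y, delta)]


# Largest component is 0.2 (< DAMPING_FACTOR): the full Newton step is taken.
small = damped_update([0.0, 0.0], [0.1, -0.2])

# Largest component is 10: the whole step is scaled by about 0.05.
large = damped_update([0.0, 0.0], [5.0, -10.0])

print(small, large)
```

Because a single scalar multiplies the whole step, the direction of the Newton update is preserved; only its length is limited.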

KlursSplitSolver ¤

Bases: KLUSplitSolver

Solves the system using the Rust-wrapped KLU sparse solver (via klu-rs) with a split interface.

This solver performs symbolic analysis ONCE during initialization and reuses the symbolic handle for subsequent solves, significantly speeding up non-linear simulations (Newton-Raphson iterations).

Best For
  • Large circuits (N > 5000) running on CPU.
  • DC Operating Points of massive meshes.

Attributes:

Name Type Description
Bp, Bi

CSC format indices (fixed structure).

csc_map_idx

Mapping from raw value indices to CSC value vector.

symbolic_handle

Pointer to the pre-computed KLU symbolic analysis.
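For readers unfamiliar with the CSC layout behind Bp and Bi, here is a minimal sketch of how a fixed sparsity pattern could be converted once and then reused. It is plain Python, `coo_to_csc_pattern` is a hypothetical helper, and it assumes the coordinates are already unique; how circulax builds these arrays internally is not shown in this section.

```python
def coo_to_csc_pattern(rows, cols, n):
    """Return (Bp, Bi, order) for unique COO coordinates of an n x n matrix.

    Bp: column pointers (length n + 1)
    Bi: row index of each stored entry, column by column
    order: permutation that reorders COO values into CSC order
    """
    order = sorted(range(len(rows)), key=lambda k: (cols[k], rows[k]))
    Bi = [rows[k] for k in order]
    Bp = [0] * (n + 1)
    for c in cols:
        Bp[c + 1] += 1          # count entries per column
    for j in range(n):
        Bp[j + 1] += Bp[j]      # prefix sum gives column start offsets
    return Bp, Bi, order


# Dense 2x2 pattern listed row by row: (0,0), (0,1), (1,0), (1,1).
rows, cols = [0, 0, 1, 1], [0, 1, 0, 1]
Bp, Bi, order = coo_to_csc_pattern(rows, cols, n=2)

# The structure is computed once; only the values change between Newton iterations.
vals = [1.0, 2.0, 3.0, 4.0]
csc_vals = [vals[k] for k in order]
print(Bp, Bi, csc_vals)
```

The point of the split interface is that everything derived from `rows` and `cols` alone can be cached, while each iteration only pays for the value-dependent work.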

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

from_component_groups

Factory method to pre-hash indices for sparse coalescence.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

This implementation always raises NotImplementedError and exists only to satisfy the AbstractLinearSolver API; internal code calls _solve_impl directly to avoid overhead.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> KlursSplitSolver

Factory method to pre-hash indices for sparse coalescence.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KlursSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbol = klurs.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbol,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
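The idea of pre-hashing indices for coalescence can be illustrated as follows. This is a sketch under assumptions, not the behaviour of `_klu_deduplicate` (which also receives the ground indices and system size and whose body is not shown here): `deduplicate` and `coalesce` are hypothetical helpers in plain Python. Several components may stamp the same (row, col) position, so each raw entry is mapped to a unique slot once, and the per-iteration work reduces to a scatter-add.

```python
def deduplicate(rows, cols):
    """Map every raw (row, col) entry to the index of its unique slot."""
    slot = {}
    map_idx = []
    for r, c in zip(rows, cols):
        # setdefault assigns the next free slot the first time a pair is seen.
        map_idx.append(slot.setdefault((r, c), len(slot)))
    u_rows = [rc[0] for rc in slot]
    u_cols = [rc[1] for rc in slot]
    return u_rows, u_cols, map_idx, len(slot)


def coalesce(vals, map_idx, n_unique):
    """Sum raw values that share a matrix position."""
    out = [0.0] * n_unique
    for v, k in zip(vals, map_idx):
        out[k] += v
    return out


# Two components both stamp position (0, 0).
rows, cols = [0, 1, 0, 1], [0, 1, 0, 0]
u_rows, u_cols, map_idx, n_unique = deduplicate(rows, cols)

unique_vals = coalesce([1.0, 2.0, 3.0, 4.0], map_idx, n_unique)
print(u_rows, u_cols, map_idx, unique_vals)
```

Only `coalesce` depends on the Jacobian values, so the index mapping can be computed once in the factory method and stored on the solver.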

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt=infinity (to open capacitors).
3. Applies ground constraints (adds a GROUND_STIFFNESS penalty at ground nodes to drive V to 0).
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (adds a GROUND_STIFFNESS penalty at ground nodes to drive V to 0).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
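The ground penalty in step 2 of `dc_step` can be tried on a two-node toy circuit. The sketch below is plain Python with a hand-written 2x2 solve; the circuit values and the `solve_2x2` helper are hypothetical. It assumes the Jacobian receives a matching GROUND_STIFFNESS entry on the ground diagonal, as the GROUND_STIFFNESS description states, although that code path is not shown in this section.

```python
GROUND_STIFFNESS = 1e9


def solve_2x2(a, b):
    """Solve a 2x2 linear system a @ x = b with Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    x0 = (b[0] * a[1][1] - a[0][1] * b[1]) / det
    x1 = (a[0][0] * b[1] - a[1][0] * b[0]) / det
    return [x0, x1]


g = 1e-3      # 1 kOhm resistor between node 0 (ground) and node 1
i_src = 1e-3  # 1 mA injected into node 1

y = [0.0, 0.0]
# KCL residual: current leaving each node through the resistor minus injected current.
residual = [g * (y[0] - y[1]), g * (y[1] - y[0]) - i_src]
jacobian = [[g, -g], [-g, g]]  # singular: no voltage reference yet

# Penalty at the ground node (index 0) in both the residual and the Jacobian diagonal.
residual[0] += GROUND_STIFFNESS * y[0]
jacobian[0][0] += GROUND_STIFFNESS

delta = solve_2x2(jacobian, [-r for r in residual])
v = [yi + di for yi, di in zip(y, delta)]
print(v)  # v[0] is about i_src / GROUND_STIFFNESS, v[1] is about i_src / g
```

The ground voltage is not forced to exactly zero; it settles at roughly the injected current divided by GROUND_STIFFNESS, which is why the constant is chosen to be very large.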

SparseSolver ¤

Bases: CircuitLinearSolver

Solves the system using JAX's iterative BiCGStab solver.

Best For
  • Large Transient Simulations on GPU (uses previous step as warm start).
  • Systems where N is too large for Dense, but we need VMAP support.

Attributes:

Name Type Description
diag_mask Array

Mask to extract diagonal elements for preconditioning.

Methods:

Name Description
assume_full_rank

Indicate if the solver assumes the operator is full rank.

compute

Performs the computation of the component for each step.

from_component_groups

Factory method to prepare indices and diagonal mask.

init

Initialize the solver state (No-op for stateless solvers).

solve_dc

Performs a robust DC Operating Point analysis (Newton-Raphson).

assume_full_rank ¤

assume_full_rank() -> bool

Indicate if the solver assumes the operator is full rank.

Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False

compute ¤

compute(state: Any, vector: Array, options: Any) -> Solution

Performs the computation of the component for each step.

This implementation always raises NotImplementedError and exists only to satisfy the AbstractLinearSolver API; internal code calls _solve_impl directly to avoid overhead.

Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.

    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.

    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)

from_component_groups classmethod ¤

from_component_groups(
    component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> SparseSolver

Factory method to prepare indices and diagonal mask.

Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "SparseSolver":
    """Factory method to prepare indices and diagonal mask."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        diag_mask=jnp.array(rows == cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
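One plausible use of `diag_mask` is a Jacobi (diagonal) preconditioner for the iterative solve. The sketch below is an assumption about that usage, written in plain Python; `jacobi_diagonal` and `apply_preconditioner` are hypothetical helpers, and the actual preconditioning code in `_solve_impl` is not shown in this section.

```python
def jacobi_diagonal(vals, rows, diag_mask, n):
    """Accumulate the matrix diagonal from COO values (duplicates are summed)."""
    diag = [0.0] * n
    for v, r, on_diag in zip(vals, rows, diag_mask):
        if on_diag:
            diag[r] += v
    return diag


def apply_preconditioner(diag, x):
    """Jacobi preconditioner: divide each entry by the matching diagonal value."""
    return [xi / di for xi, di in zip(x, diag)]


rows = [0, 0, 1, 1, 1]
cols = [0, 1, 0, 1, 1]  # position (1, 1) is stamped twice
vals = [2.0, -1.0, -1.0, 3.0, 1.0]

# Same construction as in the factory method: True where the entry is on the diagonal.
diag_mask = [r == c for r, c in zip(rows, cols)]

diag = jacobi_diagonal(vals, rows, diag_mask, n=2)
scaled = apply_preconditioner(diag, [4.0, 8.0])
print(diag_mask, diag, scaled)
```

Since the mask depends only on the index structure, it can be built once in the factory, leaving only the value-dependent sum for each solve.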

init ¤

init(operator: Any, options: Any) -> Any

Initialize the solver state (No-op for stateless solvers).

Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None

solve_dc ¤

solve_dc(component_groups: dict[str, Any], y_guess: Array) -> Array

Performs a robust DC Operating Point analysis (Newton-Raphson).

This method:

1. Detects if the system is Real or Complex based on self.is_complex.
2. Assembles the system with dt=infinity (to open capacitors).
3. Applies ground constraints (adds a GROUND_STIFFNESS penalty at ground nodes to drive V to 0).
4. Solves the linear system J * delta = -Residual.
5. Applies voltage damping to prevent exponential overshoot.

Parameters:

Name Type Description Default
component_groups dict

The circuit components and their parameters.

required
y_guess Array

Initial guess vector (Shape: [N] or [2N]).

required

Returns:

Type Description
Array

jax.Array: The converged solution vector (Flat).

Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1.  Detects if the system is Real or Complex based on `self.is_complex`.
    2.  Assembles the system with dt=infinity (to open capacitors).
    3.  Applies ground constraints (adds a GROUND_STIFFNESS penalty at ground nodes to drive V to 0).
    4.  Solves the linear system J * delta = -Residual.
    5.  Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The converged solution vector (Flat).

    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        #    We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        #    This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        #    Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))

        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
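The overall loop, a damped Newton step iterated as a fixed point, can be reproduced on a scalar problem. The sketch below is plain Python for a hypothetical series resistor and diode; all component values are illustrative assumptions, and the convergence check only imitates the rtol/atol settings above rather than Optimistix's exact criterion.

```python
import math

DAMPING_FACTOR = 0.5
DAMPING_EPS = 1e-9

V_SRC, R = 5.0, 1e3        # hypothetical 5 V source behind a 1 kOhm resistor
I_SAT, V_T = 1e-12, 0.025  # hypothetical diode saturation current and thermal voltage


def residual(v):
    """KCL at the diode node: resistor current plus diode current."""
    return (v - V_SRC) / R + I_SAT * (math.exp(v / V_T) - 1.0)


def jacobian(v):
    """Derivative of the residual with respect to the node voltage."""
    return 1.0 / R + (I_SAT / V_T) * math.exp(v / V_T)


def dc_step(v):
    """One damped Newton update, mirroring the structure of dc_step above."""
    delta = -residual(v) / jacobian(v)
    damping = min(1.0, DAMPING_FACTOR / (abs(delta) + DAMPING_EPS))
    return v + delta * damping


v = 0.0
for _ in range(100):  # same iteration cap as max_steps=100
    v_next = dc_step(v)
    converged = abs(v_next - v) <= 1e-6 + 1e-6 * abs(v_next)
    v = v_next
    if converged:
        break

print(v, residual(v))
```

From a zero initial guess the raw Newton step would be 5 V, far past the diode knee; the damping limits that first move to about 0.5 V, after which the iteration settles at a forward voltage between 0.5 V and 0.6 V.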

analyze_circuit ¤

analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False
) -> CircuitLinearSolver

Initializes a linear solver strategy for circuit analysis.

This function serves as a factory and wrapper to select and configure the appropriate numerical backend for solving the linear system of equations derived from a circuit's topology.

Parameters:

Name Type Description Default
groups list

A list of component groups that define the circuit's structure and properties.

required
num_vars int

The total number of variables in the linear system.

required
backend str

The name of the solver backend to use. Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'. Defaults to 'default', which uses the 'klu' solver.

'default'
is_complex bool

A flag indicating whether the circuit analysis involves complex numbers. Defaults to False.

False

Returns:

Type Description
CircuitLinearSolver

An instance of a circuit linear solver strategy configured for the specified backend and circuit parameters.

Raises:

Type Description
ValueError

If the specified backend is not supported.

Source code in circulax/solvers/linear.py
def analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False
) -> CircuitLinearSolver:
    """Initializes a linear solver strategy for circuit analysis.

    This function serves as a factory and wrapper to select and configure the
    appropriate numerical backend for solving the linear system of equations
    derived from a circuit's topology.

    Args:
        groups (list): A list of component groups that define the circuit's
            structure and properties.
        num_vars (int): The total number of variables in the linear system.
        backend (str, optional): The name of the solver backend to use.
            Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'.
            Defaults to 'default', which uses the 'klu' solver.
        is_complex (bool, optional): A flag indicating whether the circuit
            analysis involves complex numbers. Defaults to False.

    Returns:
        CircuitLinearSolver: An instance of a circuit linear solver strategy
        configured for the specified backend and circuit parameters.

    Raises:
        ValueError: If the specified backend is not supported.

    """
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = (
            f"Unknown backend: '{backend}'. "
            f"Available backends are {list(backends.keys())}"
        )
        raise ValueError(msg)

    linear_strategy = solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)

    return linear_strategy
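The lookup-then-construct pattern used by `analyze_circuit` can be sketched with stand-in classes. Everything below is hypothetical: the stub classes replace the real solvers, and the layout of the `backends` registry is an assumption, since the real mapping is not shown in this section.

```python
class _StubSolver:
    """Hypothetical stand-in for a CircuitLinearSolver subclass."""

    def __init__(self, num_vars, is_complex):
        self.num_vars = num_vars
        self.is_complex = is_complex

    @classmethod
    def from_component_groups(cls, groups, num_vars, *, is_complex=False):
        return cls(num_vars, is_complex)


class DenseStub(_StubSolver):
    pass


class KLUStub(_StubSolver):
    pass


# Assumed registry layout: 'default' is an alias for the KLU entry.
backends = {"default": KLUStub, "klu": KLUStub, "dense": DenseStub}


def pick_solver(groups, num_vars, backend="default", *, is_complex=False):
    """Look up the solver class by name, then delegate to its factory method."""
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = (
            f"Unknown backend: '{backend}'. "
            f"Available backends are {list(backends.keys())}"
        )
        raise ValueError(msg)
    return solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)


solver = pick_solver({}, num_vars=4, backend="dense", is_complex=True)
print(type(solver).__name__, solver.num_vars, solver.is_complex)
```

With this layout, adding a backend only requires a new registry entry and a class that provides `from_component_groups`; the calling code stays unchanged.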