Circuit Linear Solvers: Strategy Pattern
This module defines the linear-algebra strategies used by the circuit simulator.
It builds on lineax's AbstractLinearSolver base class to provide interchangeable solvers
that compose with JAX transformations (jit, vmap, grad). Individual solvers may have
exceptions; for example, KLUSolver does not support vmap automatically.
Architecture
The core idea is to separate physics assembly (computing the Jacobian entries and the residual)
from the linear solve (solving J * delta = -Residual for the Newton update).
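To make the separation concrete, here is a minimal sketch under stated assumptions. The assembly step is represented by hand-written COO arrays (rows, cols, vals) and a right-hand side, and two hypothetical functions, dense_solve and iterative_solve, play the role of interchangeable strategies. They illustrate the idea only; they are not circulax's DenseSolver or SparseSolver API.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab


def dense_solve(rows, cols, vals, rhs, n):
    """Dense strategy: scatter COO entries into an n x n matrix, then LU-solve."""
    A = jnp.zeros((n, n), dtype=vals.dtype).at[rows, cols].add(vals)  # duplicates sum
    return jnp.linalg.solve(A, rhs)


def iterative_solve(rows, cols, vals, rhs, n):
    """Iterative strategy: matrix-free BiCGStab on the same COO data."""
    def matvec(x):
        # (A @ x)[r] = sum of vals[k] * x[cols[k]] over all entries k with rows[k] == r
        return jnp.zeros(n, dtype=vals.dtype).at[rows].add(vals * x[cols])
    x, _ = bicgstab(matvec, rhs)
    return x


# Hypothetical "assembly" output: 1 A injected into node 0, a 1-ohm resistor from
# node 0 to node 1, and a 1-ohm resistor from node 1 to ground. The (1, 1) entry
# appears twice because both resistors stamp it.
rows = jnp.array([0, 0, 1, 1, 1])
cols = jnp.array([0, 1, 0, 1, 1])
vals = jnp.array([1.0, -1.0, -1.0, 1.0, 1.0])
rhs = jnp.array([1.0, 0.0])

# The same assembled data goes to either strategy; n is static under jit.
results = {
    f.__name__: jax.jit(f, static_argnums=4)(rows, cols, vals, rhs, 2)
    for f in (dense_solve, iterative_solve)
}
print(results)
```

Because both strategies consume the same assembled data, swapping one for the other leaves the physics code untouched.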
Classes:
| Name | Description |
|---|---|
| CircuitLinearSolver | Abstract base defining the interface and common DC logic. |
| DenseSolver | Uses JAX's native dense solver (LU decomposition). Best for small circuits (N < 2000) and GPU. |
| KLUSolver | Uses the KLU sparse solver (via klujax). Best for large circuits on CPU. |
| SparseSolver | Uses JAX's iterative BiCGStab. Best for large transient simulations on GPU. |
Functions:
| Name | Description |
|---|---|
| analyze_circuit | Initializes a linear solver strategy for circuit analysis. |
Attributes:
| Name | Type | Description |
|---|---|---|
| DAMPING_EPS | float | Small additive epsilon that prevents division by zero in the damping formula. |
| DAMPING_FACTOR | float | Newton-step damping coefficient: each update is scaled by min(1, DAMPING_FACTOR / (max(abs(δy)) + DAMPING_EPS)), so no entry of y changes by more than about DAMPING_FACTOR per iteration. |
| DC_DT | float | Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0). |
| GROUND_STIFFNESS | float | Penalty added to ground-node diagonal entries to enforce V = 0. |
DAMPING_EPS (module attribute)
DAMPING_EPS: float = 1e-09
Small additive epsilon that prevents division by zero in the damping formula.
DAMPING_FACTOR (module attribute)
DAMPING_FACTOR: float = 0.5
Newton-step damping coefficient: each update is scaled by min(1, DAMPING_FACTOR / (max|δy| + DAMPING_EPS)), so no entry of y changes by more than about DAMPING_FACTOR per iteration.
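The damping rule can be checked in isolation. The function below copies the formula used in solve_dc (the name damped_update is hypothetical): it rescales the whole step, preserving its direction, so that its largest entry is at most about DAMPING_FACTOR.

```python
import jax.numpy as jnp

DAMPING_FACTOR = 0.5  # module values documented above
DAMPING_EPS = 1e-9


def damped_update(y, delta):
    """Apply the solve_dc damping rule to a Newton step delta."""
    max_change = jnp.max(jnp.abs(delta))
    damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
    return y + delta * damping


y = jnp.zeros(3)
small_step = damped_update(y, jnp.array([0.1, -0.2, 0.05]))  # max 0.2 < 0.5: applied in full
large_step = damped_update(y, jnp.array([4.0, -2.0, 1.0]))   # scaled by 0.5 / 4 = 0.125
```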
DC_DT (module attribute)
Effective timestep used for DC analysis; makes capacitor stamps vanish (C/dt → 0).
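DC_DT works because, assuming a backward-Euler companion model for capacitors (an assumption; the integration scheme is not shown in this section), a capacitor is stamped as a conductance C/dt. The sketch below uses 1e18 as a placeholder for the actual DC_DT value, which is not given here.

```python
def capacitor_companion_conductance(C, dt):
    """Backward-Euler companion model of a capacitor: i = (C / dt) * (v - v_prev).

    The conductance stamped into the Jacobian is G_eq = C / dt.
    """
    return C / dt


C = 1e-6  # 1 uF
g_transient = capacitor_companion_conductance(C, 1e-9)  # 1 ns step: about 1000 S, strongly coupled
g_dc = capacitor_companion_conductance(C, 1e18)         # placeholder DC step: effectively an open circuit
```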
GROUND_STIFFNESS (module attribute)
GROUND_STIFFNESS: float = 1000000000.0
Penalty added to ground-node diagonal entries to enforce V = 0.
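The ground penalty can be illustrated on a two-node toy network. This is a dense sketch with hypothetical values; in circulax the extra GROUND_STIFFNESS entries are placed by the solver's index-building code, which is not shown in this section.

```python
import jax.numpy as jnp

GROUND_STIFFNESS = 1e9  # value as documented above

# Hypothetical network: 1 A is injected into node 0, which is tied to node 1
# through a 1-ohm resistor; node 1 is the ground node. Without a reference the
# conductance matrix is singular (each row sums to zero).
G = jnp.array([[1.0, -1.0],
               [-1.0, 1.0]])
i_inj = jnp.array([1.0, 0.0])
ground_indices = jnp.array([1])

# Penalty: add GROUND_STIFFNESS to the ground node's diagonal entry.
G_grounded = G.at[ground_indices, ground_indices].add(GROUND_STIFFNESS)
v = jnp.linalg.solve(G_grounded, i_inj)  # v[1] is pushed to 1 / GROUND_STIFFNESS, v[0] to about 1 V
```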
CircuitLinearSolver
Bases: AbstractLinearSolver
Abstract Base Class for all circuit linear solvers.
This class provides the unified interface for:
1. Storing static matrix structure (indices, rows, cols).
2. Handling Real vs. Complex-Unrolled system configurations.
3. Providing a robust Newton-Raphson DC Operating Point solver.
Attributes:
| Name | Type | Description |
|---|---|---|
| ground_indices | Array | Indices of nodes connected to ground (forced to 0 V). |
| is_complex | bool | Static flag. If True, the system is 2N x 2N (real/imaginary parts unrolled); if False, the system is N x N (real). |
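The 2N x 2N layout follows from splitting A x = b into real and imaginary parts. The sketch below assumes a [real; imaginary] block ordering, which is one common convention; the ordering circulax actually uses is not shown in this section, and unroll_complex and solve_unrolled are hypothetical helpers.

```python
import jax.numpy as jnp


def unroll_complex(A, b):
    """Rewrite the complex N x N system A x = b as a real 2N x 2N system.

    With A = Ar + i*Ai, x = xr + i*xi and b = br + i*bi:
        [Ar  -Ai] [xr]   [br]
        [Ai   Ar] [xi] = [bi]
    """
    Ar, Ai = A.real, A.imag
    big = jnp.block([[Ar, -Ai], [Ai, Ar]])
    rhs = jnp.concatenate([b.real, b.imag])
    return big, rhs


def solve_unrolled(A, b):
    big, rhs = unroll_complex(A, b)
    sol = jnp.linalg.solve(big, rhs)  # purely real solve
    n = b.shape[0]
    return sol[:n] + 1j * sol[n:]     # re-pack [xr; xi] into a complex vector


A = jnp.array([[2 + 1j, -1 + 0j], [-1 + 0j, 2 - 1j]])
b = jnp.array([1 + 0j, 1j])
x = solve_unrolled(A, b)
```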
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
assume_full_rank
assume_full_rank() -> bool
Indicate whether the solver assumes the operator is full rank.
Source code in circulax/solvers/linear.py
```python
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False
```
compute
compute(state: Any, vector: Array, options: Any) -> Solution
Required by the lineax AbstractLinearSolver interface but not used internally: the circuit solvers call _solve_impl directly to avoid overhead, so calling compute raises NotImplementedError.
Source code in circulax/solvers/linear.py
```python
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Satisfy the lineax solver API (not used internally).

    The circuit solvers call `_solve_impl` directly to avoid overhead, so this
    method always raises NotImplementedError.
    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)
```
init
Initialize the solver state (no-op for stateless solvers).
Source code in circulax/solvers/linear.py
```python
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None
```
solve_dc
Performs a robust DC operating-point analysis (Newton-Raphson).
This method:
1. Detects whether the system is real or complex based on self.is_complex.
2. Assembles the system with dt=DC_DT (effectively infinite, so capacitors act as open circuits).
3. Applies ground constraints by adding a GROUND_STIFFNESS penalty term at the ground nodes.
4. Solves the linear system J * delta = -Residual.
5. Damps the update so no entry of y changes by more than about DAMPING_FACTOR per iteration, preventing overshoot in exponential device models (diodes, transistors).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| component_groups | dict | The circuit components and their parameters. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
Returns:
| Type | Description |
|---|---|
| Array | The final iterate as a flat vector. Because the loop runs with throw=False, it is returned even if the 100-step limit is reached without convergence. |
Source code in circulax/solvers/linear.py
```python
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).

    This method:
    1. Detects if the system is Real or Complex based on `self.is_complex`.
    2. Assembles the system with dt=DC_DT (effectively infinite, so capacitors are open).
    3. Applies ground constraints (a GROUND_STIFFNESS penalty at the ground nodes).
    4. Solves the linear system J * delta = -Residual.
    5. Applies voltage damping to prevent exponential overshoot.

    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).

    Returns:
        jax.Array: The final iterate (Flat). With throw=False it is returned
        even if the iteration did not converge within max_steps.
    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )

        # 2. Apply Ground Constraints to Residual
        # We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        # This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])

        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value

        # 4. Apply Voltage Limiting (Damping)
        # Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
        return y + delta * damping

    # 5. Iterate the damped Newton map y -> dc_step(y) to a fixed point (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
```
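The loop in solve_dc can be illustrated on a single-node toy circuit. This is a hypothetical sketch: the circuit values, the residual function, and the use of jax.lax.fori_loop with a fixed 100 iterations (instead of optx.fixed_point with a convergence test) are assumptions made to keep the example self-contained. The damping formula is the one in the source above.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy circuit: 5 V source -> 1 kOhm resistor -> diode -> ground.
# The single unknown is the diode voltage, y = [v]; the residual is KCL at that node.
VS, R, IS, VT = 5.0, 1e3, 1e-14, 0.02585
DAMPING_FACTOR, DAMPING_EPS = 0.5, 1e-9


def residual(y):
    v = y[0]
    i_resistor = (v - VS) / R               # current leaving the node through R
    i_diode = IS * (jnp.exp(v / VT) - 1.0)  # Shockley diode current to ground
    return jnp.array([i_resistor + i_diode])


def dc_step(y, _):
    J = jax.jacfwd(residual)(y)                # 1 x 1 Jacobian via autodiff
    delta = jnp.linalg.solve(J, -residual(y))  # Newton direction: J * delta = -R
    max_change = jnp.max(jnp.abs(delta))
    damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
    return y + delta * damping


# Fixed iteration count in place of optx.fixed_point's convergence test.
y_dc = jax.lax.fori_loop(0, 100, lambda _, y: dc_step(y, None), jnp.array([0.0]))
print(y_dc, residual(y_dc))
```

Without damping, the first Newton step from 0 V would jump to roughly 5 V, where exp(v / VT) overflows float32 (JAX's default dtype). Capping each step at 0.5 V keeps the iterate representable until it settles near 0.69 V.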
DenseSolver
Bases: CircuitLinearSolver
Solves the system using dense matrix factorization (LU).
Best for:
- Small to medium circuits (N < 2000).
- Wavelength sweeps (AC analysis) on GPU.
- Systems where vmap parallelism is critical.
Attributes:
| Name | Type | Description |
|---|---|---|
| static_rows | Array | Row indices for placing values into the dense matrix. |
| static_cols | Array | Column indices for placing values into the dense matrix. |
| g_leak | float | Leakage conductance added to the diagonal to prevent singularity. |
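Why g_leak helps: a subcircuit with no DC path to ground has a singular conductance matrix, because adding a constant to every node voltage leaves all currents unchanged. The sketch below uses a hypothetical floating resistor and an illustrative g_leak of 1e-9. It enables 64-bit floats globally, because a 1e-9 diagonal shift is lost to rounding next to 1.0 in float32.

```python
import jax
jax.config.update("jax_enable_x64", True)  # must run before any arrays are created
import jax.numpy as jnp

# Hypothetical floating subcircuit: a 1-ohm resistor between nodes 0 and 1 with
# no connection to ground, driven by 1 A into node 0 and out of node 1.
G = jnp.array([[1.0, -1.0],
               [-1.0, 1.0]])
i_inj = jnp.array([1.0, -1.0])

g_leak = 1e-9                          # illustrative value, not the library default
G_leaky = G + g_leak * jnp.eye(2)      # tiny conductance from every node to ground

det_plain = jnp.linalg.det(G)          # zero: node voltages are not unique
v = jnp.linalg.solve(G_leaky, i_inj)   # unique: the leak fixes the common-mode level
```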
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| from_component_groups | Factory method to pre-calculate indices for the dense matrix. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
assume_full_rank, compute: identical to the CircuitLinearSolver methods documented above.
from_component_groups (classmethod)
Factory method to pre-calculate indices for the dense matrix.
Source code in circulax/solvers/linear.py
```python
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "DenseSolver":
    """Factory method to pre-calculate indices for the dense matrix."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
```
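The dense strategy and its vmap friendliness can be sketched as follows. Here static_rows and static_cols are hand-written stand-ins for the arrays from_component_groups would compute, and dense_solve is a hypothetical helper. Because the sparsity structure is fixed, a whole batch of value vectors (for example, a parameter sweep) can be solved with a single jax.vmap.

```python
import jax
import jax.numpy as jnp

# Hypothetical pre-computed structure standing in for static_rows / static_cols.
static_rows = jnp.array([0, 0, 1, 1, 1])
static_cols = jnp.array([0, 1, 0, 1, 1])


def dense_solve(vals, rhs, n=2):
    """Scatter values into a dense matrix (duplicates sum) and LU-solve."""
    A = jnp.zeros((n, n), dtype=vals.dtype).at[static_rows, static_cols].add(vals)
    return jnp.linalg.solve(A, rhs)


# A batch of three value vectors sharing one structure: all conductances scaled by s.
scales = jnp.array([1.0, 2.0, 4.0])
batch_vals = scales[:, None] * jnp.array([1.0, -1.0, -1.0, 1.0, 1.0])
rhs = jnp.array([1.0, 0.0])

# Map over the value axis only; the right-hand side is shared.
batch_x = jax.vmap(dense_solve, in_axes=(0, None))(batch_vals, rhs)
```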
init, solve_dc: identical to the CircuitLinearSolver methods documented above.
KLUSolver
Bases: CircuitLinearSolver
Solves the system using the KLU sparse solver (via klujax).
Best for:
- Large circuits (N > 5000) running on CPU.
- DC operating points of massive meshes.
- Cases where DenseSolver runs out of memory (OOM).
Note: KLUSolver does not support vmap (batching) automatically.
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| from_component_groups | Factory method to pre-hash indices for sparse coalescence. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
assume_full_rank, compute: identical to the CircuitLinearSolver methods documented above.
from_component_groups (classmethod)
Factory method to pre-hash indices for sparse coalescence.
Source code in circulax/solvers/linear.py
```python
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
```
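The pre-hashing step for sparse coalescence can be sketched like this. deduplicate_coo is a hypothetical stand-in for _klu_deduplicate, whose body is not shown and which also handles ground and leakage entries omitted here. The one-time NumPy work finds the unique (row, col) pairs; each iteration then needs only a jax.ops.segment_sum, mirroring the call in factor_jacobian.

```python
import numpy as np
import jax
import jax.numpy as jnp


def deduplicate_coo(rows, cols, n):
    """Hypothetical one-time setup: unique (row, col) pairs plus, for every raw
    entry, the index of the unique slot it accumulates into."""
    keys = np.asarray(rows) * n + np.asarray(cols)  # hash (row, col) -> single int
    uniq, map_idx = np.unique(keys, return_inverse=True)
    return uniq // n, uniq % n, map_idx, len(uniq)


rows = [0, 0, 1, 1, 1]
cols = [0, 1, 0, 1, 1]
u_rows, u_cols, map_idx, n_unique = deduplicate_coo(rows, cols, n=2)

# Per iteration (traceable under jit): sum raw stamp values into the unique slots.
raw_vals = jnp.array([1.0, -1.0, -1.0, 1.0, 1.0])
coalesced = jax.ops.segment_sum(raw_vals, jnp.asarray(map_idx), num_segments=n_unique)
```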
init, solve_dc: identical to the CircuitLinearSolver methods documented above.
KLUSplitFactorSolver
Bases: KLUSplitSolver
Solves the system using the KLU sparse solver (via klujax) with a split interface.
This solver performs symbolic analysis once, during initialization, and reuses
the symbolic handle for subsequent solves, which can significantly speed up non-linear
simulations (Newton-Raphson iterations). This variant goes one step further: the numeric
factorization can also be computed once (factor_jacobian) and reused across solves
(solve_with_frozen_jacobian).
Best for:
- Large circuits (N > 5000) running on CPU.
- DC operating points of massive meshes.
Attributes:
| Name | Description |
|---|---|
| Bp, Bi | CSC format indices (fixed structure). |
| csc_map_idx | Mapping from raw value indices to the CSC value vector. |
| symbolic_handle | Pointer to the pre-computed KLU symbolic analysis. |
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| factor_jacobian | Factor the Jacobian and return the numeric handle. |
| from_component_groups | Factory method to pre-hash indices for sparse coalescence. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
| solve_with_frozen_jacobian | Solve using a pre-computed numeric factorization (frozen-Jacobian Newton). |
assume_full_rank, compute: identical to the CircuitLinearSolver methods documented above.
factor_jacobian
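Numerically factor the Jacobian and return the KLU numeric handle. Before factoring, one GROUND_STIFFNESS value per ground index and sys_size values of g_leak are appended to the raw Jacobian values, which are then coalesced into the unique sparsity pattern.
Source code in circulax/solvers/linear.py
```python
def factor_jacobian(self, all_vals: jax.Array) -> jax.Array:
    """Factor the Jacobian and return numeric handle."""
    # Ground-stiffness and leakage values, appended after the circuit's own entries.
    g_vals = jnp.full(self.ground_indices.shape[0], GROUND_STIFFNESS, dtype=all_vals.dtype)
    l_vals = jnp.full(self.sys_size, self.g_leak, dtype=all_vals.dtype)
    raw_vals = jnp.concatenate([all_vals, g_vals, l_vals])
    # Sum duplicate (row, col) contributions into the unique sparsity pattern.
    coalesced_vals = jax.ops.segment_sum(
        raw_vals, self.map_idx, num_segments=self.n_unique
    )
    return klujax.factor(
        self.u_rows, self.u_cols, coalesced_vals, self.symbolic_handle
    )
```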
from_component_groups (classmethod)
Factory method to pre-hash indices for sparse coalescence.
Source code in circulax/solvers/linear.py
```python
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbolic = klujax.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbolic,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
```
init, solve_dc: identical to the CircuitLinearSolver methods documented above.
solve_with_frozen_jacobian
solve_with_frozen_jacobian(residual: Array, numeric: Array) -> Solution
Solve using pre-computed numeric factorization (for frozen Jacobian Newton).
Source code in circulax/solvers/linear.py
```python
def solve_with_frozen_jacobian(
    self, residual: jax.Array, numeric: jax.Array
) -> lx.Solution:
    """Solve using pre-computed numeric factorization (for frozen Jacobian Newton)."""
    solution = klujax.solve_with_numeric(
        numeric, residual, self._handle_wrapper.handle
    )
    return lx.Solution(
        value=solution.reshape(residual.shape),
        result=lx.RESULTS.successful,
        state=None,
        stats={},
    )
```
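The payoff of factor_jacobian plus solve_with_frozen_jacobian is the frozen-Jacobian (chord) Newton method: factor once, then reuse the factors for several iterations. The sketch below substitutes JAX's dense jax.scipy.linalg.lu_factor and lu_solve for the klujax calls, and uses a hypothetical two-equation system F with a root at (1, 2).

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve


def F(y):
    """Hypothetical nonlinear system with a root at y = (1, 2)."""
    return jnp.array([y[0] ** 2 + y[1] - 3.0,
                      y[0] + y[1] ** 2 - 5.0])


y = jnp.array([1.2, 1.8])
lu_piv = lu_factor(jax.jacfwd(F)(y))  # factor ONCE (dense stand-in for factor_jacobian)

for _ in range(30):  # chord / frozen-Jacobian Newton iterations
    delta = lu_solve(lu_piv, -F(y))  # reuse the factors (stand-in for solve_with_frozen_jacobian)
    y = y + delta
```

The trade-off: each iteration costs only a pair of triangular solves, but convergence is linear rather than quadratic, and it relies on the frozen Jacobian staying close to the true one.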
KLUSplitSolver
Bases: CircuitLinearSolver
Solves the system using the KLU sparse solver (via klujax) with a split interface.
This solver performs symbolic analysis once, during initialization, and reuses
the symbolic handle for subsequent solves, which can significantly speed up non-linear
simulations (Newton-Raphson iterations).
Best for:
- Large circuits (N > 5000) running on CPU.
- DC operating points of massive meshes.
Attributes:
| Name | Description |
|---|---|
| Bp, Bi | CSC format indices (fixed structure). |
| csc_map_idx | Mapping from raw value indices to the CSC value vector. |
| symbolic_handle | Pointer to the pre-computed KLU symbolic analysis. |
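The attributes above describe a fixed compressed-sparse-column (CSC) structure. The following is a hypothetical sketch of building Bp, Bi and csc_map_idx once, then filling the CSC value vector on each iteration. Note that the from_component_groups code in this class passes deduplicated COO indices to klujax.analyze, so where any CSC conversion actually happens is not visible in this section.

```python
import numpy as np
import jax
import jax.numpy as jnp


def coo_to_csc_structure(rows, cols, n):
    """Hypothetical one-time setup: fixed CSC indices (Bp, Bi) plus csc_map_idx,
    which sends each raw COO entry to its slot in the CSC value vector."""
    rows, cols = np.asarray(rows), np.asarray(cols)
    keys = cols * n + rows                                # column-major ordering key
    uniq, csc_map_idx = np.unique(keys, return_inverse=True)
    Bi = uniq % n                                         # row index of each stored entry
    counts = np.bincount(uniq // n, minlength=n)          # stored entries per column
    Bp = np.concatenate([[0], np.cumsum(counts)])         # column pointers
    return Bp, Bi, csc_map_idx


rows = [0, 0, 1, 1, 1]
cols = [0, 1, 0, 1, 1]
Bp, Bi, csc_map_idx = coo_to_csc_structure(rows, cols, n=2)

# Per Newton iteration only the numbers change: accumulate them in CSC order.
raw_vals = jnp.array([1.0, -1.0, -1.0, 1.0, 1.0])
Bx = jax.ops.segment_sum(raw_vals, jnp.asarray(csc_map_idx), num_segments=len(Bi))
```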
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| from_component_groups | Factory method to pre-hash indices for sparse coalescence. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
assume_full_rank, compute: identical to the CircuitLinearSolver methods documented above.
from_component_groups (classmethod)
Factory method to pre-hash indices for sparse coalescence.
Source code in circulax/solvers/linear.py
```python
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KLUSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbolic = klujax.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbolic,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
```
init, solve_dc: identical to the CircuitLinearSolver methods documented above.
KlursSplitSolver
Bases: KLUSplitSolver
Solves the system using the Rust-wrapped KLU sparse solver (via klu-rs) with a split interface.
This solver performs symbolic analysis once, during initialization, and reuses
the symbolic handle for subsequent solves, which can significantly speed up non-linear
simulations (Newton-Raphson iterations).
Best for:
- Large circuits (N > 5000) running on CPU.
- DC operating points of massive meshes.
Attributes:
| Name | Description |
|---|---|
| Bp, Bi | CSC format indices (fixed structure). |
| csc_map_idx | Mapping from raw value indices to the CSC value vector. |
| symbolic_handle | Pointer to the pre-computed KLU symbolic analysis. |
Methods:
| Name | Description |
|---|---|
| assume_full_rank | Indicate whether the solver assumes the operator is full rank. |
| compute | Required by the lineax API; raises NotImplementedError because internal code calls _solve_impl directly. |
| from_component_groups | Factory method to pre-hash indices for sparse coalescence. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Perform a robust DC operating-point analysis (Newton-Raphson). |
assume_full_rank, compute: identical to the CircuitLinearSolver methods documented above.
from_component_groups (classmethod)
Factory method to pre-hash indices for sparse coalescence.
Source code in circulax/solvers/linear.py
```python
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "KlursSplitSolver":
    """Factory method to pre-hash indices for sparse coalescence."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    u_rows, u_cols, map_idx, n_unique = _klu_deduplicate(rows, cols, ground_idxs, sys_size)
    symbol = klurs.analyze(u_rows, u_cols, sys_size)
    return cls(
        u_rows=jnp.array(u_rows),
        u_cols=jnp.array(u_cols),
        map_idx=jnp.array(map_idx),
        n_unique=n_unique,
        _handle_wrapper=symbol,
        ground_indices=jnp.array(ground_idxs),
        sys_size=sys_size,
        is_complex=is_complex,
    )
```
init
Initialize the solver state (no-op for stateless solvers).
Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None
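solve_dc
Performs a robust DC operating-point analysis (damped Newton-Raphson).
This method:
1. Selects real or complex assembly based on self.is_complex.
2. Assembles the system with dt = DC_DT, an effectively infinite timestep, so capacitor stamps (C/dt) vanish and capacitors act as open circuits.
3. Enforces ground constraints with a penalty: GROUND_STIFFNESS * V is added to the residual at each ground node, which drives those voltages to 0.
4. Solves the linear system J * delta = -Residual.
5. Damps the Newton step so that no component changes by more than DAMPING_FACTOR, preventing overshoot in exponential device models (diodes, transistors).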
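Steps 4 and 5 can be tried out on a hypothetical one-node circuit: a 5 V source feeds a diode to ground through 1 kΩ. The sketch below is self-contained and reuses the damping rule from the source code below, `min(1, DAMPING_FACTOR / (max|delta| + DAMPING_EPS))`. It makes two substitutions, both assumptions for illustration: `jax.jacfwd` stands in for the assembled Jacobian, and a fixed-count `jax.lax.fori_loop` stands in for the Optimistix loop. The circuit values and function names are also illustrative.

```python
import jax
import jax.numpy as jnp

DAMPING_FACTOR = 0.5  # same values as the module attributes
DAMPING_EPS = 1e-9

# Hypothetical circuit: V_S --[R]-- node --|>|-- ground
V_S, R = 5.0, 1e3        # source voltage [V], series resistance [ohm]
I_S, V_T = 1e-14, 0.025  # diode saturation current [A], thermal voltage [V]


def residual(y):
    """KCL at the single node: current through R plus diode current."""
    v = y[0]
    return jnp.array([(v - V_S) / R + I_S * (jnp.exp(v / V_T) - 1.0)])


def damped_newton_step(y):
    jac = jax.jacfwd(residual)(y)                # dense 1x1 Jacobian
    delta = jnp.linalg.solve(jac, -residual(y))  # J * delta = -R
    max_change = jnp.max(jnp.abs(delta))
    damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
    return y + delta * damping                   # no component moves more than 0.5 V


def solve_dc_toy(y0, num_steps=100):
    return jax.lax.fori_loop(0, num_steps, lambda _, y: damped_newton_step(y), y0)


y = solve_dc_toy(jnp.zeros(1))
```

Starting from 0 V, the undamped first step would jump about 5 V into the diode's exponential region, where `exp(v / V_T)` overflows in float32. With damping, each move is capped at 0.5 V, and the iteration settles near the diode's forward drop of about 0.67 V.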
Parameters:
| Name | Type | Description | Default |
| component_groups | dict | The circuit components and their parameters. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
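Returns:
| Type | Description |
| Array | The final solution vector (flat, same shape as y_guess). The loop runs with throw=False and max_steps=100, so the last iterate is returned without an error even if the tolerances were not met. |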
Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).
    This method:
    1. Detects if the system is Real or Complex based on `self.is_complex`.
    2. Assembles the system with dt=DC_DT (effectively infinite, to open capacitors).
    3. Applies ground constraints (a GROUND_STIFFNESS penalty at ground nodes).
    4. Solves the linear system J * delta = -Residual.
    5. Applies voltage damping to prevent exponential overshoot.
    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).
    Returns:
        jax.Array: The final solution vector (Flat).
    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        # 2. Apply Ground Constraints to Residual
        # We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        # This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])
        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value
        # 4. Apply Voltage Limiting (Damping)
        # Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
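In step 2 of dc_step, the Python `for` loop over `self.ground_indices` is unrolled at trace time into one scatter per ground node. A single vectorized `.at[...].add` gives the same penalized residual, because JAX scatter-adds accumulate repeated indices exactly as the loop does. The sketch below demonstrates this with hypothetical helper names; it is not the library's code.

```python
import jax.numpy as jnp

GROUND_STIFFNESS = 1e9  # same value as the module attribute


def ground_penalty_loop(total_f, y, ground_indices):
    # Mirrors the per-index loop used in solve_dc.
    out = total_f
    for idx in ground_indices:
        out = out.at[idx].add(GROUND_STIFFNESS * y[idx])
    return out


def ground_penalty_vectorized(total_f, y, ground_indices):
    # One scatter-add over all ground nodes; duplicate indices accumulate like the loop.
    return total_f.at[ground_indices].add(GROUND_STIFFNESS * y[ground_indices])
```

The Jacobian side of the penalty (the GROUND_STIFFNESS term on the ground-node diagonal described in the module attributes) is outside the scope of this sketch.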
SparseSolver
Bases: CircuitLinearSolver
Solves the system using JAX's iterative BiCGStab solver.
Best For
- Large transient simulations on GPU (the previous step's solution serves as a warm start).
- Systems where N is too large for DenseSolver but VMAP support is still needed.
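The sketch below illustrates the approach described above. It assumes the Jacobian is held as COO triplets (`rows`, `cols`, `vals`), similar to the `static_rows`/`static_cols` fields set in `from_component_groups`. The matrix is applied matrix-free through `segment_sum`, and `jax.scipy.sparse.linalg.bicgstab` receives the previous solution as `x0`. The helper names `coo_matvec` and `solve_warm` are hypothetical.

```python
import jax.numpy as jnp
from jax.ops import segment_sum
from jax.scipy.sparse.linalg import bicgstab


def coo_matvec(rows, cols, vals, n):
    """Return x -> A @ x for A stored as COO triplets (duplicate entries are summed)."""
    def matvec(x):
        return segment_sum(vals * x[cols], rows, num_segments=n)
    return matvec


def solve_warm(rows, cols, vals, b, x_prev):
    n = b.shape[0]
    # x0=x_prev: starting from a nearby previous solution typically cuts the iteration count.
    x, _ = bicgstab(coo_matvec(rows, cols, vals, n), b, x0=x_prev, tol=1e-5)
    return x
```

Because the operator is a plain function, `jax.vmap` can batch it over many right-hand sides or parameter sets, which is the VMAP-friendliness noted above.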
Attributes:
| Name | Type | Description |
| diag_mask | Array | Mask to extract diagonal elements for preconditioning. |
Methods:
| Name | Description |
| assume_full_rank | Indicates whether the solver assumes the operator is full rank. |
| compute | Required by the lineax solver interface but not implemented; raises NotImplementedError (internal code calls _solve_impl directly). |
| from_component_groups | Factory method to prepare indices and diagonal mask. |
| init | Initialize the solver state (no-op for stateless solvers). |
| solve_dc | Performs a robust DC operating-point analysis (damped Newton-Raphson). |
assume_full_rank
assume_full_rank() -> bool
Indicate if the solver assumes the operator is full rank.
Source code in circulax/solvers/linear.py
def assume_full_rank(self) -> bool:
    """Indicate if the solver assumes the operator is full rank."""
    return False
compute
compute(state: Any, vector: Array, options: Any) -> Solution
Required by the lineax AbstractLinearSolver interface, but not implemented here: calling it always raises NotImplementedError. Internal code calls _solve_impl directly to avoid the overhead of the generic API.
Source code in circulax/solvers/linear.py
def compute(self, state: Any, vector: jax.Array, options: Any) -> lx.Solution:
    """Performs the computation of the component for each step.
    In our case, we usually call `_solve_impl` directly to avoid overhead,
    but this satisfies the API.
    """
    msg = "Directly call _solve_impl for internal use."
    raise NotImplementedError(msg)
from_component_groups
classmethod
Factory method to prepare indices and diagonal mask.
Source code in circulax/solvers/linear.py
@classmethod
def from_component_groups(
    cls, component_groups: dict[str, Any], num_vars: int, *, is_complex: bool = False
) -> "SparseSolver":
    """Factory method to prepare indices and diagonal mask."""
    rows, cols, ground_idxs, sys_size = _build_index_arrays(
        component_groups, num_vars, is_complex
    )
    return cls(
        static_rows=jnp.array(rows),
        static_cols=jnp.array(cols),
        diag_mask=jnp.array(rows == cols),
        sys_size=sys_size,
        ground_indices=jnp.array(ground_idxs),
        is_complex=is_complex,
    )
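`diag_mask` marks which raw stamps sit on the diagonal. One plausible use is a Jacobi (diagonal) preconditioner, sketched below; the source above does not show how the mask is consumed, so treat this as an assumption. The masked values are summed per row to form the diagonal D, and x -> D^{-1} x is passed as `M` to `bicgstab`. The helper names are hypothetical.

```python
import jax.numpy as jnp
from jax.ops import segment_sum
from jax.scipy.sparse.linalg import bicgstab


def jacobi_preconditioner(rows, vals, diag_mask, n):
    """Hypothetical helper: return x -> D^{-1} x, with D built from masked COO values."""
    # Sum every stamp that lands on the diagonal into its row's slot.
    diag = segment_sum(jnp.where(diag_mask, vals, 0.0), rows, num_segments=n)
    nonzero = diag != 0.0
    # Rows with a zero diagonal (e.g. MNA voltage-source branch rows) are left unscaled.
    inv_diag = jnp.where(nonzero, 1.0 / jnp.where(nonzero, diag, 1.0), 1.0)
    return lambda x: inv_diag * x


def solve_preconditioned(rows, cols, vals, diag_mask, b):
    n = b.shape[0]

    def matvec(x):
        # COO matrix-vector product; duplicate (row, col) stamps are summed.
        return segment_sum(vals * x[cols], rows, num_segments=n)

    M = jacobi_preconditioner(rows, vals, diag_mask, n)
    x, _ = bicgstab(matvec, b, tol=1e-5, M=M)
    return x
```

The nested `jnp.where` avoids evaluating `1.0 / 0.0` at all, which keeps gradients finite if the solve is differentiated.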
init
Initialize the solver state (No-op for stateless solvers).
Source code in circulax/solvers/linear.py
def init(self, operator: Any, options: Any) -> Any:  # noqa: ARG002
    """Initialize the solver state (No-op for stateless solvers)."""
    return None
solve_dc
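Performs a robust DC operating-point analysis (damped Newton-Raphson).
This method:
1. Selects real or complex assembly based on self.is_complex.
2. Assembles the system with dt = DC_DT, an effectively infinite timestep, so capacitor stamps (C/dt) vanish and capacitors act as open circuits.
3. Enforces ground constraints with a penalty: GROUND_STIFFNESS * V is added to the residual at each ground node, which drives those voltages to 0.
4. Solves the linear system J * delta = -Residual.
5. Damps the Newton step so that no component changes by more than DAMPING_FACTOR, preventing overshoot in exponential device models (diodes, transistors).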
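Steps 2 to 5 can also be written as equations. The capacitor line assumes a backward-Euler companion model; the module itself only states that the stamp scales as C/dt. The remaining lines transcribe the dc_step code. Here 𝒢 denotes the ground indices, K = GROUND_STIFFNESS, and J is the Jacobian handed to the linear solver.

```latex
\begin{aligned}
i_C^{(n)} &= \frac{C}{\Delta t}\bigl(v^{(n)} - v^{(n-1)}\bigr), \qquad
G_C = \frac{C}{\Delta t} \;\xrightarrow{\;\Delta t \to \infty\;}\; 0
  \quad \text{(capacitor becomes an open circuit)} \\
\tilde f_i(y) &= f_i(y) + K\,y_i \quad (i \in \mathcal{G}), \qquad
\tilde f_i(y) = f_i(y) \quad (i \notin \mathcal{G}) \\
J\,\delta &= -\tilde f(y_k) \\
\alpha_k &= \min\!\left(1,\; \frac{\text{DAMPING\_FACTOR}}{\lVert \delta \rVert_\infty + \text{DAMPING\_EPS}}\right), \qquad
y_{k+1} = y_k + \alpha_k\,\delta, \qquad
\lVert \alpha_k\,\delta \rVert_\infty \le \text{DAMPING\_FACTOR}
\end{aligned}
```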
Parameters:
| Name | Type | Description | Default |
| component_groups | dict | The circuit components and their parameters. | required |
| y_guess | Array | Initial guess vector (shape [N] or [2N]). | required |
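Returns:
| Type | Description |
| Array | The final solution vector (flat, same shape as y_guess). The loop runs with throw=False and max_steps=100, so the last iterate is returned without an error even if the tolerances were not met. |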
Source code in circulax/solvers/linear.py
def solve_dc(
    self, component_groups: dict[str, Any], y_guess: jax.Array
) -> jax.Array:
    """Performs a robust DC Operating Point analysis (Newton-Raphson).
    This method:
    1. Detects if the system is Real or Complex based on `self.is_complex`.
    2. Assembles the system with dt=DC_DT (effectively infinite, to open capacitors).
    3. Applies ground constraints (a GROUND_STIFFNESS penalty at ground nodes).
    4. Solves the linear system J * delta = -Residual.
    5. Applies voltage damping to prevent exponential overshoot.
    Args:
        component_groups (dict): The circuit components and their parameters.
        y_guess (jax.Array): Initial guess vector (Shape: [N] or [2N]).
    Returns:
        jax.Array: The final solution vector (Flat).
    """

    def dc_step(y: jax.Array, _: Any) -> jax.Array:
        # 1. Assemble System (DC_DT effectively removes time-dependent terms like C*dv/dt)
        if self.is_complex:
            total_f, _, all_vals = assemble_system_complex(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        else:
            total_f, _, all_vals = assemble_system_real(
                y, component_groups, t1=0.0, dt=DC_DT
            )
        # 2. Apply Ground Constraints to Residual
        # We add a massive penalty (GROUND_STIFFNESS * V) to the residual at ground nodes.
        # This forces the solver to drive V -> 0.
        total_f_grounded = total_f
        for idx in self.ground_indices:
            total_f_grounded = total_f_grounded.at[idx].add(GROUND_STIFFNESS * y[idx])
        # 3. Solve Linear System (J * delta = -R)
        sol = self._solve_impl(all_vals, -total_f_grounded)
        delta = sol.value
        # 4. Apply Voltage Limiting (Damping)
        # Prevents the solver from taking huge steps that crash exponentials (diodes/transistors).
        max_change = jnp.max(jnp.abs(delta))
        damping = jnp.minimum(1.0, DAMPING_FACTOR / (max_change + DAMPING_EPS))
        return y + delta * damping

    # 5. Run Newton Loop (Optimistix)
    solver = optx.FixedPointIteration(rtol=1e-6, atol=1e-6)
    sol = optx.fixed_point(dc_step, solver, y_guess, max_steps=100, throw=False)
    return sol.value
analyze_circuit
Initializes a linear solver strategy for circuit analysis.
This function serves as a factory and wrapper to select and configure the
appropriate numerical backend for solving the linear system of equations
derived from a circuit's topology.
Parameters:
| Name | Type | Description | Default |
| groups | list | A list of component groups that define the circuit's structure and properties. | required |
| num_vars | int | The total number of variables in the linear system. | required |
| backend | str | The name of the solver backend to use. Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'. Defaults to 'default', which uses the 'klu' solver. | 'default' |
| is_complex | bool | Whether the circuit analysis involves complex numbers (keyword-only). Defaults to False. | False |
Returns:
| Type | Description |
| CircuitLinearSolver | An instance of a circuit linear solver strategy configured for the specified backend and circuit parameters. |
Raises:
| Type | Description |
| ValueError | If the specified backend is not supported. |
Source code in circulax/solvers/linear.py
def analyze_circuit(
    groups: list, num_vars: int, backend: str = "default", *, is_complex: bool = False
) -> CircuitLinearSolver:
    """Initializes a linear solver strategy for circuit analysis.
    This function serves as a factory and wrapper to select and configure the
    appropriate numerical backend for solving the linear system of equations
    derived from a circuit's topology.
    Args:
        groups (list): A list of component groups that define the circuit's
            structure and properties.
        num_vars (int): The total number of variables in the linear system.
        backend (str, optional): The name of the solver backend to use.
            Supported backends are 'klu', 'klu_split', 'dense', and 'sparse'.
            Defaults to 'default', which uses the 'klu' solver.
        is_complex (bool, optional): A flag indicating whether the circuit
            analysis involves complex numbers. Defaults to False.
    Returns:
        CircuitLinearSolver: An instance of a circuit linear solver strategy
            configured for the specified backend and circuit parameters.
    Raises:
        ValueError: If the specified backend is not supported.
    """
    solver_class = backends.get(backend)
    if solver_class is None:
        msg = (
            f"Unknown backend: '{backend}'. "
            f"Available backends are {list(backends.keys())}"
        )
        raise ValueError(msg)
    linear_strategy = solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)
    return linear_strategy
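analyze_circuit is a registry lookup followed by a classmethod constructor. The standalone sketch below shows that pattern. The stand-in classes `FakeDense` and `FakeKLU` and the `BACKENDS` dict are hypothetical; the real `backends` mapping and solver classes are not shown in this section.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeDense:
    num_vars: int
    is_complex: bool

    @classmethod
    def from_component_groups(cls, groups: Any, num_vars: int, *, is_complex: bool = False):
        # A real strategy would precompute sparsity indices here, once per topology.
        return cls(num_vars=num_vars, is_complex=is_complex)


class FakeKLU(FakeDense):
    pass


# Hypothetical registry: 'default' aliases 'klu', mirroring the docstring above.
BACKENDS = {"default": FakeKLU, "klu": FakeKLU, "dense": FakeDense}


def pick_solver(groups: Any, num_vars: int, backend: str = "default", *, is_complex: bool = False):
    solver_class = BACKENDS.get(backend)
    if solver_class is None:
        msg = f"Unknown backend: '{backend}'. Available backends are {list(BACKENDS)}"
        raise ValueError(msg)
    return solver_class.from_component_groups(groups, num_vars, is_complex=is_complex)
```

All index preprocessing happens in the classmethod, so the returned strategy can plausibly be reused across every Newton or transient step for the same circuit topology.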