Photonics Transient
In this advanced demonstration, we simulate the non-linear transient response of a photonic circuit. We revisit the standard waveguide model but introduce Two-Photon Absorption (TPA).
In linear optics, the loss coefficient is a constant ($\alpha$). In the non-linear regime, loss becomes intensity-dependent ($\alpha + \beta I$, where $I$ is the optical intensity and $\beta$ is the TPA coefficient). This creates an Optical Limiting effect: as the optical power increases, the material absorbs a larger fraction of it, effectively "clamping" the output power. This example demonstrates two critical capabilities of the solver:
Dynamic S-Matrices: Unlike the previous linear examples, where the system matrix was constant, here the scattering parameters $S(t)$ depend on the instantaneous state $|E(t)|^2$. The solver re-evaluates the component physics at every time-step (the initial step in this example is $10^{-13}$ s).
Complex Envelopes: Optical signals oscillate at hundreds of terahertz (about $193\,\text{THz}$ near 1550 nm). To keep the simulation tractable, we simulate the complex slowly-varying envelope $A(t)$ rather than the raw field, using a Real-Imaginary Flattening strategy that maps the complex state onto the real-valued state vector the solver requires.
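The Real-Imaginary Flattening idea can be illustrated in isolation. The sketch below is a minimal example, not circulax code: the helper names `flatten_complex` and `unflatten_complex` are hypothetical, and it uses NumPy for portability (the same calls exist in `jax.numpy`). It assumes the layout used later in this example, where the real parts occupy the first half of the flat vector and the imaginary parts the second half.

```python
import numpy as np


def flatten_complex(a):
    """Map a complex vector of length N to a real vector of length 2N: [Re(a), Im(a)]."""
    return np.concatenate([np.real(a), np.imag(a)])


def unflatten_complex(y_flat):
    """Inverse map: split the real vector in half and recombine into a complex vector."""
    n = y_flat.shape[0] // 2
    return y_flat[:n] + 1j * y_flat[n:]


# Hypothetical two-node envelope state, for illustration only
a = np.array([1.0 + 2.0j, -0.5 + 0.25j])
y_flat = flatten_complex(a)          # real-valued, length 4
a_back = unflatten_complex(y_flat)   # complex, length 2
```

A real-valued solver only ever sees `y_flat`; the complex envelope is recovered afterwards by the inverse map, which is the same slicing pattern the post-processing step of this example applies to the saved solution.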
What to Expect: A high-power Gaussian pulse is launched into the waveguide. In the linear regime, the output is just a scaled, delayed copy of the input. With tpa_coeff > 0, however, the pulse is reshaped: the peak of the Gaussian (high intensity) is flattened or "squashed" by the non-linear loss, while the tails (low intensity) pass through with standard linear attenuation.
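The expected flattening can be previewed without the circuit solver. The sketch below is a standalone illustration under stated assumptions, not the waveguide model used by circulax: it integrates the textbook propagation equation $dI/dz = -(\alpha + \beta I)\,I$ in closed form, ignores delay and dispersion, and works in normalized units. The function name `tpa_output_intensity` and the value and units chosen for `beta` are hypothetical; only the 20 dB/cm loss, the 1000 µm length, and the peak level of 100 are borrowed from the netlist settings.

```python
import numpy as np


def tpa_output_intensity(i_in, alpha, beta, length):
    """Closed-form solution of dI/dz = -(alpha + beta * I) * I over `length`."""
    lin = np.exp(-alpha * length)    # linear (Beer-Lambert) transmission
    l_eff = (1.0 - lin) / alpha      # effective length, shortened by linear loss
    return i_in * lin / (1.0 + beta * i_in * l_eff)


# Illustrative assumptions, in normalized units
alpha = 20.0 * np.log(10.0) / 10.0   # 20 dB/cm expressed in 1/cm
beta = 0.5                           # hypothetical TPA coefficient
length = 0.1                         # 1000 um expressed in cm

t = np.linspace(-3.0, 3.0, 601)      # time in units of the pulse width
i_in = 100.0 * np.exp(-t**2)         # Gaussian intensity envelope
i_out = tpa_output_intensity(i_in, alpha, beta, length)

transmission = i_out / i_in
peak_T = transmission[len(t) // 2]   # at the pulse peak (t = 0)
tail_T = transmission[0]             # far out in the tail (t = -3)
```

With these assumed numbers the tails see essentially only the linear loss (about 2 dB), while the peak is attenuated far more strongly (roughly 9 dB), which is the "squashed" shape described above.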
import diffrax
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from circulax.compiler import compile_netlist
from circulax.components.base_component import PhysicsReturn, Signals, States, component
from circulax.components.electronic import Resistor
from circulax.components.photonic import OpticalSourcePulse
from circulax.solvers import analyze_circuit, setup_transient
print("--- DEMO: Photonic Transient (Flat Vector Fix) ---")
jax.config.update("jax_enable_x64", True)
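# NOTE: OpticalWaveguide is neither defined nor imported in this listing;
# it must be in scope before this point.
models_map = {
    "waveguide": OpticalWaveguide,
    "source": OpticalSourcePulse,
    "resistor": Resistor,
    "ground": lambda: 0,
}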
net_dict = {
    "instances": {
        "GND": {"component": "ground"},
        "I1": {"component": "source", "settings": {"power": 100.0, "delay": 0.1e-9}},
        "WG1": {
            "component": "waveguide",
            "settings": {"length_um": 1000.0, "loss_dB_cm": 20.0, "tpa_coeff": 5e-1},
        },
        "R1": {"component": "resistor", "settings": {"R": 1.0}},
    },
    "connections": {"GND,p1": ("I1,p2", "R1,p2"), "I1,p1": "WG1,p1", "WG1,p2": "R1,p1"},
}
print("1. Compiling...")
groups, sys_size, port_map = compile_netlist(net_dict, models_map)
linear_strat = analyze_circuit(groups, sys_size, is_complex=True)
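print("2. Solving DC Operating Point...")
# The complex state is stored as a flat real vector of length 2 * sys_size
# (real parts first, then imaginary parts), hence the doubled size below.
y_guess_flat = jnp.zeros(sys_size * 2, dtype=jnp.float64)
y_op_flat = linear_strat.solve_dc(groups, y_guess_flat)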
print(f" DC Converged. Norm: {jnp.linalg.norm(y_op_flat):.2e}")
transient_sim = setup_transient(groups=groups, linear_strategy=linear_strat)
t_max = 1.0e-9
saveat = diffrax.SaveAt(ts=jnp.linspace(0, t_max, 500))
print("3. Running Transient Simulation...")
sol = transient_sim(
    t0=0.0,
    t1=t_max,
    dt0=1e-13,
    y0=y_op_flat,
    args=(groups, sys_size),
    saveat=saveat,
    max_steps=100000,
    throw=False,
    stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
)
if sol.result == diffrax.RESULTS.successful:
    print(" ✅ Simulation Successful")
    ts = sol.ts * 1e9  # seconds -> nanoseconds
    ys_flat = sol.ys
    # Rebuild the complex envelope: real parts first, imaginary parts second
    ys_complex = ys_flat[:, :sys_size] + 1j * ys_flat[:, sys_size:]
    node_in_idx = port_map["I1,p1"]
    node_out_idx = port_map["R1,p1"]
    v_in = ys_complex[:, node_in_idx]
    v_out = ys_complex[:, node_out_idx]
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 6))
    axes = axes.ravel()
    # Left panel: input and output envelopes versus time
    axes[0].plot(ts, 20 * jnp.log10(jnp.abs(v_in)), "k--", label="Input Pulse")
    axes[0].plot(ts, 20 * jnp.log10(jnp.abs(v_out)), "r-", label="Output (After WG)")
    axes[0].set_title("Photonic Transient Response")
    axes[0].set_xlabel("Time (ns)")
    axes[0].set_ylabel("Power (dBm)")
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    # Right panel: transfer curve (output level against input level)
    axes[1].set_title("Output Power vs Input Power")
    axes[1].set_xlabel("Input Power (dB)")
    axes[1].set_ylabel("Output Power (dB)")
    axes[1].plot(
        20 * jnp.log10(jnp.abs(v_in)),
        (20 * jnp.log10(jnp.abs(v_out))),
        "r-",
        label="Loss",
    )
    axes[1].grid(True, alpha=0.3)
else:
    print(f"❌ Simulation Failed: {sol.result}")
--- DEMO: Photonic Transient (Flat Vector Fix) ---
1. Compiling...
2. Solving DC Operating Point...
 DC Converged. Norm: 1.92e+00
3. Running Transient Simulation...
 ✅ Simulation Successful