
Issues with KLU backend when differentiating through circuit models #42

Open
asanders opened this issue Dec 10, 2024 · 1 comment

@asanders

Hi, I'm doing some circuit modelling and want to take gradients through circuit models using SAX. The default KLU backend raises an error under forward-mode differentiation (`jax.jacfwd`) and returns an incorrect, all-zero result under reverse-mode differentiation (`jax.jacrev`), while the FG backend gives the correct result in both cases. Based on some preliminary testing of klujax, I suspect the problem lies in `klujax.solve` and/or `klujax.coo_mul_vec`.
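Since the suspicion falls on the derivative rules of `klujax.solve` and `klujax.coo_mul_vec`, it may help to write down what forward-mode differentiation of those two operations should give. The sketch below is a plain NumPy reference, not klujax's implementation: it assumes the sparse matrix is stored as COO triplets `Ai`, `Aj`, `Ax` (the `Ax` name mirrors the klujax error message), densifies it for simplicity, and checks the tangents against central finite differences. `coo_to_dense`, `solve_jvp` and `coo_mul_vec_jvp` are hypothetical helper names, and the matrix values are arbitrary.

```python
import numpy as np

def coo_to_dense(Ai, Aj, Ax, n):
    """Assemble a dense n x n matrix from COO triplets (duplicate entries are summed)."""
    A = np.zeros((n, n), dtype=complex)
    np.add.at(A, (Ai, Aj), Ax)
    return A

def solve_jvp(Ai, Aj, Ax, b, dAx, db):
    """Forward-mode rule for x = A^{-1} b: differentiating A x = b gives A dx = db - dA x."""
    n = b.shape[0]
    A = coo_to_dense(Ai, Aj, Ax, n)
    x = np.linalg.solve(A, b)
    dA = coo_to_dense(Ai, Aj, dAx, n)
    dx = np.linalg.solve(A, db - dA @ x)  # one extra solve with the same matrix
    return x, dx

def coo_mul_vec_jvp(Ai, Aj, Ax, b, dAx, db):
    """Forward-mode rule for y = A b: the product rule gives dy = dA b + A db."""
    n = b.shape[0]
    A = coo_to_dense(Ai, Aj, Ax, n)
    dA = coo_to_dense(Ai, Aj, dAx, n)
    return A @ b, dA @ b + A @ db

# Illustrative invertible 3x3 complex system in COO form (arbitrary values).
Ai = np.array([0, 1, 2, 0, 2])
Aj = np.array([0, 1, 2, 2, 0])
Ax = np.array([2.0, 3.0 + 1.0j, 4.0, 0.5j, -1.0])
b = np.array([1.0, 1.0j, -2.0])
dAx = np.array([0.1, -0.2j, 0.3, 0.05, 0.2j])  # arbitrary tangent direction
db = np.array([0.0, 1.0, 0.5j])

x, dx = solve_jvp(Ai, Aj, Ax, b, dAx, db)
y, dy = coo_mul_vec_jvp(Ai, Aj, Ax, b, dAx, db)

# Independent check: central finite differences along the same direction.
h = 1e-6

def dense(Ax_):
    return coo_to_dense(Ai, Aj, Ax_, 3)

dx_fd = (np.linalg.solve(dense(Ax + h * dAx), b + h * db)
         - np.linalg.solve(dense(Ax - h * dAx), b - h * db)) / (2 * h)
dy_fd = (dense(Ax + h * dAx) @ (b + h * db)
         - dense(Ax - h * dAx) @ (b - h * db)) / (2 * h)
print(np.max(np.abs(dx - dx_fd)), np.max(np.abs(dy - dy_fd)))
```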

Here is a minimal example. The circuit is a single coupler, and the model reports the power at both output ports for light entering `in0`, to see how the output power changes with the `coupling` parameter:

```python
import sax
from sax.models import coupler
import jax
import jax.numpy as jnp

netlist = {
    "instances": {"coupler0": "coupler"},
    "connections": {},
    "ports": {"in0": "coupler0,in0", "in1": "coupler0,in1", "out0": "coupler0,out0", "out1": "coupler0,out1"}
}
models = {"coupler": coupler}

circuit, _ = sax.circuit(netlist, models, backend='klu', return_type='sdense')

def model(coupling: float = 0.5) -> jax.Array:
    s, ports = circuit(coupling=coupling)  # run the circuit model
    s_fwd = s[ports['in0'], [ports['out0'], ports['out1']]]  # extract the forward scattering parameters
    return jnp.abs(s_fwd) ** 2  # report the power transmission
```
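To tell which backend's gradient is right without trusting either autodiff path, the Jacobian can also be estimated by central finite differences, which only call the circuit in the forward direction. This is a minimal sketch: `fd_jacobian` is a hypothetical helper, and the default step `h = 1e-6` assumes double precision, as the `float64` outputs below suggest.

```python
import numpy as np

def fd_jacobian(f, x0, h=1e-6):
    """Central-difference derivative of a vector-valued function f at scalar x0.

    Only evaluates f in the forward direction, so it does not depend on any
    backend's custom autodiff rules. h = 1e-6 assumes float64 evaluation.
    """
    f_plus = np.asarray(f(x0 + h), dtype=float)
    f_minus = np.asarray(f(x0 - h), dtype=float)
    return (f_plus - f_minus) / (2 * h)

# Hypothetical stand-in with the same output shape as model: powers [1 - c, c].
demo = fd_jacobian(lambda c: [1 - c, c], 0.5)
print(demo)
```

With the `model` above, `fd_jacobian(model, 0.5)` can be compared directly with `jax.jacfwd(model)(0.5)` and `jax.jacrev(model)(0.5)` for each backend.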

With `backend='fg'`:

```python
>>> model(0.5)
Array([0.5, 0.5], dtype=float64)
>>> jax.jacfwd(model)(0.5)
Array([-1.,  1.], dtype=float64)
>>> jax.jacrev(model)(0.5)
Array([-1.,  1.], dtype=float64)
```
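The FG results match what an ideal lossless coupler predicts. Assuming the SAX `coupler` model has through amplitude sqrt(1 - c) and cross amplitude 1j*sqrt(c) for coupling c (an assumption here, but consistent with the [0.5, 0.5] powers above), the output powers are 1 - c and c, so the Jacobian should be [-1, 1] for any c in (0, 1). The chain-rule calculation is sketched below; `coupler_power_jacobian` is a hypothetical name.

```python
import math

def coupler_power_jacobian(c):
    """Analytic d|S|^2/dc for the in0 row of an ideal lossless 2x2 coupler.

    Assumed amplitudes: through tau = sqrt(1 - c), cross kappa = sqrt(c)
    (the 1j phase on the cross term does not affect |S|^2).
    """
    tau, kappa = math.sqrt(1 - c), math.sqrt(c)
    dtau_dc = -1 / (2 * tau)     # derivative of sqrt(1 - c)
    dkappa_dc = 1 / (2 * kappa)  # derivative of sqrt(c)
    powers = [tau ** 2, kappa ** 2]
    # Chain rule: d(a^2)/dc = 2 * a * da/dc for a real amplitude a.
    jac = [2 * tau * dtau_dc, 2 * kappa * dkappa_dc]
    return powers, jac

powers, jac = coupler_power_jacobian(0.5)
print(powers, jac)
```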

With `backend='klu'`:

```python
>>> model(0.5)
Array([0.5, 0.5], dtype=float64)
>>> jax.jacfwd(model)(0.5)
ValueError: Ax should be at most 2D with shape: (n_lhs, n_nz). Got: (1, 1, 8). Note: jax.vmap is supported. Use it if needed.
>>> jax.jacrev(model)(0.5)
Array([0.,  0.], dtype=float64)
```
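A zero gradient from `jax.jacrev` suggests the reverse-mode cotangent is being lost somewhere. For reference, the reverse-mode rule for x = A^{-1} b needs one solve with the transpose: lam = A^{-T} g for an output cotangent g, then b_bar = lam and, for each stored entry k, Ax_bar[k] = -lam[Ai[k]] * x[Aj[k]]. The NumPy sketch below is not klujax's code; it assumes COO storage and JAX's non-conjugating transpose convention for complex values, and checks the rule against forward mode through the identity sum(g * dx) = sum(b_bar * db) + sum(Ax_bar * dAx). `solve_vjp` and the matrix values are hypothetical.

```python
import numpy as np

def coo_to_dense(Ai, Aj, Ax, n):
    """Assemble a dense n x n matrix from COO triplets (duplicate entries are summed)."""
    A = np.zeros((n, n), dtype=complex)
    np.add.at(A, (Ai, Aj), Ax)
    return A

def solve_vjp(Ai, Aj, Ax, b, g):
    """Reverse-mode rule for x = A^{-1} b with A in COO form.

    For an output cotangent g: lam = A^{-T} g (transpose, no conjugation),
    b_bar = lam and Ax_bar[k] = -lam[Ai[k]] * x[Aj[k]].
    """
    n = b.shape[0]
    A = coo_to_dense(Ai, Aj, Ax, n)
    x = np.linalg.solve(A, b)
    lam = np.linalg.solve(A.T, g)
    Ax_bar = -lam[Ai] * x[Aj]  # one cotangent per stored nonzero
    return x, Ax_bar, lam

# Illustrative invertible 3x3 complex system in COO form (arbitrary values).
Ai = np.array([0, 1, 2, 0, 2])
Aj = np.array([0, 1, 2, 2, 0])
Ax = np.array([2.0, 3.0 + 1.0j, 4.0, 0.5j, -1.0])
b = np.array([1.0, 1.0j, -2.0])
g = np.array([1.0, -0.5j, 0.25])  # arbitrary output cotangent

x, Ax_bar, b_bar = solve_vjp(Ai, Aj, Ax, b, g)

# Consistency with forward mode, where A dx = db - dA x.
dAx = np.array([0.1, -0.2j, 0.3, 0.05, 0.2j])
db = np.array([0.0, 1.0, 0.5j])
A = coo_to_dense(Ai, Aj, Ax, 3)
dx = np.linalg.solve(A, db - coo_to_dense(Ai, Aj, dAx, 3) @ x)
lhs = np.sum(g * dx)
rhs = np.sum(b_bar * db) + np.sum(Ax_bar * dAx)
print(abs(lhs - rhs), np.abs(Ax_bar).max())
```

In this small example `Ax_bar` is non-zero, as the formula implies whenever `x` and `lam` are non-zero.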

Do you know if there is a solution for this?

@flaport
Owner

flaport commented Dec 12, 2024

This definitely looks like a bug in klujax to me. I'll take a look and see if I can find it.
